diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..74e0155f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,44 @@ +name: Lint Python scripts + +on: + push: + paths: + - '**.py' + pull_request: + paths: + - '**.py' + +jobs: + lint: + name: Lint Python scripts + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.11', '3.12'] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip + uses: actions/cache@v4 + id: cache-pip + with: + path: ~/.cache/pip + key: pip_cache_py_${{ matrix.python-version }}_${{ hashFiles('**/pyproject.toml') }} + + - name: Install linters + run: | + python -m pip install --upgrade pip + pip install isort flake8 + + - name: Run isort + run: isort --check-only --diff . + + - name: Run flake8 + run: flake8 . --max-line-length=120 diff --git a/examples/tutorials/cbioportal.ipynb b/examples/tutorials/cbioportal.ipynb index a800b030..f489ff80 100644 --- a/examples/tutorials/cbioportal.ipynb +++ b/examples/tutorials/cbioportal.ipynb @@ -13,7 +13,7 @@ "import numpy as np\n", "import pandas as pd\n", "\n", - "STUDY_ID = 'pancan_pcawg_2020'" + "STUDY_ID = 'coad_cptac_2019'" ] }, { @@ -37,6 +37,24 @@ "cbio.get_cbioportal_data(cbio.study_id)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "b08670ec", + "metadata": {}, + "outputs": [], + "source": [ + "# Check the specific files the notebook needs are present\n", + "required_files = [\"data_cna.txt\", \"data_mutations.txt\", \"data_clinical_sample.txt\"]\n", + "missing = [f for f in required_files if not os.path.exists(os.path.join(STUDY_ID, f))]\n", + "assert not missing, f\"Missing required files in {STUDY_ID}: {missing}\"\n", + "\n", + "print(f\"All required files present in ./{STUDY_ID}/:\")\n", + "for f in required_files:\n", + " size = 
class LazyModule:
    """Proxy for a submodule that is imported on first attribute access.

    Creating the proxy imports nothing. The real module is loaded the first
    time an attribute is looked up on the proxy. A failed import is
    remembered, so it is not retried and every later access reports the same
    error.

    Args:
        module_name: Name of the submodule, relative to ``package``.
        package: Dotted name of the package that contains the submodule.
            Defaults to the package in which this class is defined.
    """

    # Attributes owned by the proxy itself. ``__getattr__`` must never forward
    # these: on an instance created without ``__init__`` (``copy.copy`` and
    # ``pickle`` build instances that way) forwarding them would recurse
    # forever, because the import helper reads ``self._module`` first.
    _OWN_ATTRIBUTES = frozenset(
        {"_module_name", "_package", "_module", "_import_error"}
    )

    def __init__(self, module_name, package=None):
        self._module_name = module_name
        self._package = __name__ if package is None else package
        self._module = None
        self._import_error = None

    def _import_module(self):
        """Return the real module, importing it on the first call.

        Raises:
            ImportError: If the submodule cannot be imported. The original
                exception is attached as ``__cause__``. The same message is
                raised on every call after a failure.
        """
        if self._module is not None:
            return self._module

        if self._import_error is None:
            import importlib

            try:
                self._module = importlib.import_module(
                    f".{self._module_name}", package=self._package
                )
            except ImportError as exc:
                # Remember the failure so the import is not attempted again.
                self._import_error = exc
            else:
                return self._module

        # Build a fresh exception each time instead of re-raising a cached
        # instance, so the message is identical on every access and the
        # traceback of a stored exception does not keep growing.
        raise ImportError(
            f"Failed to import {self._module_name} module: {self._import_error}. "
            f"This usually means some dependencies are missing. "
            f"Try installing required packages or check your environment."
        ) from self._import_error

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for attributes of the
        # wrapped module or for proxy attributes that were never set.
        if name in self._OWN_ATTRIBUTES:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._import_module(), name)

    def __dir__(self):
        # Listing attributes must not trigger the import; an unloaded proxy
        # reports no attributes.
        if self._module is None:
            return []
        return dir(self._module)

    def __repr__(self):
        if self._module is not None:
            state = "loaded"
        elif self._import_error is not None:
            state = "import failed"
        else:
            state = "not loaded"
        return f"<LazyModule '{self._module_name}' ({state})>"


# Create lazy module proxies - these are NOT imported yet
modules = LazyModule("modules")
data = LazyModule("data")
main = LazyModule("main")
models = LazyModule("models")
feature_selection = LazyModule("feature_selection")
utils = LazyModule("utils")


# Import commonly used classes directly for easy access
# These will be imported lazily when first accessed
def _get_data_importer():
    """Lazy getter for DataImporter class."""
    return data.DataImporter


def _get_models():
    """Lazy getter for model classes."""
    return models


def __getattr__(name):
    """Resolve package attributes that are provided lazily (PEP 562).

    ``DataImporter`` is advertised in ``__all__`` but is not bound at import
    time; it is fetched from the lazily imported ``data`` module on first use.

    Raises:
        AttributeError: For any other unknown attribute.
        ImportError: If the ``data`` module cannot be imported.
    """
    if name == "DataImporter":
        return _get_data_importer()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def print_test_installation():
    """Print the commands that download the demo dataset and smoke-test the CLI."""
    dataset_url = (
        "https://bimsbstatic.mdc-berlin.de/akalin/buyar/"
        "flexynesis-benchmark-datasets/dataset1.tgz"
    )
    smoke_test_args = (
        "--data_path dataset1",
        "--model_class DirectPred",
        "--target_variables Erlotinib",
        "--hpo_iter 1",
        "--features_top_percentile 5",
        "--data_types gex,cnv",
    )
    lines = (
        "Test Installation:",
        " # Download and extract test dataset",
        f" curl -L -o dataset1.tgz {dataset_url}",
        " tar -xzvf dataset1.tgz",
        "",
        " # Test the installation (should finish within a minute on a typical CPU)",
        " flexynesis " + " ".join(smoke_test_args),
    )
    # One write of the joined text produces the same output as one print per line.
    print("\n".join(lines))


def print_help():
    """Print the short usage summary: required options, a few common ones,
    the installation test commands and a pointer to the online documentation.
    """
    model_choices = ",".join(
        (
            "DirectPred",
            "supervised_vae",
            "MultiTripletNetwork",
            "CrossModalPred",
            "GNN",
            "RandomForest",
            "SVM",
            "XGBoost",
            "RandomSurvivalForest",
        )
    )
    docs_url = (
        "https://bimsbstatic.mdc-berlin.de/akalin/buyar/"
        "flexynesis/site/getting_started/"
    )

    summary = (
        "usage: flexynesis [-h] --data_path DATA_PATH --model_class "
        f"{{{model_choices}}} --data_types DATA_TYPES",
        "",
        "Flexynesis model training interface",
        "",
        "options:",
        " -h, --help show complete help with all options",
        " --data_path DATA_PATH",
        " (Required) Path to the folder with train/test data files",
        f" --model_class {{{model_choices}}}",
        " (Required) The kind of model class to instantiate",
        " --data_types DATA_TYPES",
        " (Required) Which omic data matrices to work on, "
        "comma-separated: e.g. 'gex,cnv'",
        " --hpo_iter HPO_ITER Number of iterations for hyperparameter "
        "optimisation (default: 100)",
        " --device {auto,cuda,mps,cpu}",
        " Device type: 'auto' (automatic detection), 'cuda' "
        "(NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)",
        " --use_gpu (Optional) DEPRECATED: Use --device instead. If "
        "set, attempts to use CUDA/GPU if available.",
        "",  # blank separator line before the installation test section
    )
    print("\n".join(summary))

    print_test_installation()

    print()
    print(f" See the documentation for more details at {docs_url}.")
+ ) def print_full_help(): - print("usage: flexynesis [-h] --data_path DATA_PATH --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest} " - "[--gnn_conv_type {GC,GCN,SAGE}] [--target_variables TARGET_VARIABLES] [--covariates COVARIATES] [--surv_event_var SURV_EVENT_VAR] [--surv_time_var SURV_TIME_VAR] " - "[--config_path CONFIG_PATH] [--fusion_type {early,intermediate}] [--hpo_iter HPO_ITER] [--finetuning_samples FINETUNING_SAMPLES] " - "[--variance_threshold VARIANCE_THRESHOLD] [--correlation_threshold CORRELATION_THRESHOLD] [--restrict_to_features RESTRICT_TO_FEATURES] " - "[--subsample SUBSAMPLE] [--features_min FEATURES_MIN] [--features_top_percentile FEATURES_TOP_PERCENTILE] --data_types DATA_TYPES " - "[--input_layers INPUT_LAYERS] [--output_layers OUTPUT_LAYERS] [--outdir OUTDIR] [--prefix PREFIX] [--log_transform {True,False}] " - "[--early_stop_patience EARLY_STOP_PATIENCE] [--hpo_patience HPO_PATIENCE] [--val_size VAL_SIZE] [--use_cv] [--use_loss_weighting {True,False}] " - "[--evaluate_baseline_performance] [--threads THREADS] [--num_workers NUM_WORKERS] [--device {auto,cuda,mps,cpu}] [--use_gpu] [--feature_importance_method {IntegratedGradients,GradientShap,Both}] " - "[--disable_marker_finding] [--string_organism STRING_ORGANISM] [--string_node_name {gene_name,gene_id}] [--user_graph USER_GRAPH] [--safetensors]") + print( + "usage: flexynesis [-h] --data_path DATA_PATH --model_class " + "{DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN," + "RandomForest,SVM,XGBoost,RandomSurvivalForest} " + "[--gnn_conv_type {GC,GCN,SAGE}] [--target_variables TARGET_VARIABLES] " + "[--covariates COVARIATES] [--surv_event_var SURV_EVENT_VAR] " + "[--surv_time_var SURV_TIME_VAR] [--config_path CONFIG_PATH] " + "[--fusion_type {early,intermediate}] [--hpo_iter HPO_ITER] " + "[--finetuning_samples FINETUNING_SAMPLES] " + "[--variance_threshold VARIANCE_THRESHOLD] " + 
"[--correlation_threshold CORRELATION_THRESHOLD] " + "[--restrict_to_features RESTRICT_TO_FEATURES] [--subsample SUBSAMPLE] " + "[--features_min FEATURES_MIN] [--features_top_percentile " + "FEATURES_TOP_PERCENTILE] --data_types DATA_TYPES [--input_layers " + "INPUT_LAYERS] [--output_layers OUTPUT_LAYERS] [--outdir OUTDIR] " + "[--prefix PREFIX] [--log_transform {True,False}] " + "[--early_stop_patience EARLY_STOP_PATIENCE] [--hpo_patience " + "HPO_PATIENCE] [--val_size VAL_SIZE] [--use_cv] " + "[--use_loss_weighting {True,False}] " + "[--evaluate_baseline_performance] [--threads THREADS] " + "[--num_workers NUM_WORKERS] [--device {auto,cuda,mps,cpu}] [--use_gpu] " + "[--feature_importance_method {IntegratedGradients,GradientShap,Both}] " + "[--disable_marker_finding] [--string_organism STRING_ORGANISM] " + "[--string_node_name {gene_name,gene_id}] [--user_graph USER_GRAPH] " + "[--safetensors]" + ) print() print("Flexynesis model training interface") print() @@ -63,7 +109,10 @@ def print_full_help(): # --- NEW: inference-only flags (keep in full help) --- print(" --pretrained_model PRETRAINED_MODEL") - print(" Use a saved .pth/.safetensors model for inference (skip training)") + print( + " Use a saved .pth/.safetensors model for inference " + "(skip training)" + ) print(" --artifacts ARTIFACTS") print(" Path to training-time artifacts .joblib or .json") print(" --data_path_test DATA_PATH_TEST") @@ -73,83 +122,199 @@ def print_full_help(): # --- existing flags (keep full list) --- print(" -h, --help show this help message and exit") print(" --data_path DATA_PATH") - print(" (Required) Path to the folder with train/test data files") - print(" --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}") + print( + " (Required) Path to the folder with train/test data files" + ) + print( + " --model_class {DirectPred,supervised_vae,MultiTripletNetwork," + 
"CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" + ) print(" (Required) The kind of model class to instantiate") print(" --gnn_conv_type {GC,GCN,SAGE}") - print(" If model_class is set to GNN, choose which graph convolution type to use") + print( + " If model_class is set to GNN, choose which graph " + "convolution type to use" + ) print(" --target_variables TARGET_VARIABLES") - print(" (Optional if survival variables are not set to None). Which variables in 'clin.csv' to use for predictions, comma-separated if multiple") + print( + " (Optional if survival variables are not set to " + "None). Which variables in 'clin.csv' to use for predictions, " + "comma-separated if multiple" + ) print(" --covariates COVARIATES") - print(" Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple") + print( + " Which variables in 'clin.csv' to be used as feature " + "covariates, comma-separated if multiple" + ) print(" --surv_event_var SURV_EVENT_VAR") - print(" Which column in 'clin.csv' to use as event/status indicator for survival modeling") + print( + " Which column in 'clin.csv' to use as " + "event/status indicator for survival modeling" + ) print(" --surv_time_var SURV_TIME_VAR") - print(" Which column in 'clin.csv' to use as time/duration indicator for survival modeling") + print( + " Which column in 'clin.csv' to use as time/duration " + "indicator for survival modeling" + ) print(" --config_path CONFIG_PATH") - print(" Optional path to an external hyperparameter configuration file in YAML format.") + print( + " Optional path to an external hyperparameter " + "configuration file in YAML format." 
+ ) print(" --fusion_type {early,intermediate}") - print(" How to fuse the omics layers (default: intermediate)") - print(" --hpo_iter HPO_ITER Number of iterations for hyperparameter optimisation (default: 100)") + print( + " How to fuse the omics layers (default: intermediate)" + ) + print( + " --hpo_iter HPO_ITER Number of iterations for hyperparameter optimisation (default: 100)" + ) print(" --finetuning_samples FINETUNING_SAMPLES") - print(" Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning (default: 0)") + print( + " Number of samples from the test dataset to use " + "for fine-tuning the model. Set to 0 to disable fine-tuning " + "(default: 0)" + ) print(" --variance_threshold VARIANCE_THRESHOLD") - print(" Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)") + print( + " Variance threshold (as percentile) to drop low " + "variance features (default is 1; set to 0 for no variance filtering)" + ) print(" --correlation_threshold CORRELATION_THRESHOLD") - print(" Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)") + print( + " Correlation threshold to drop highly redundant " + "features (default is 0.8; set to 1 for no redundancy filtering)" + ) print(" --restrict_to_features RESTRICT_TO_FEATURES") - print(" Restrict the analysis to the list of features provided by the user (default is None)") + print( + " Restrict the analysis to the list of features " + "provided by the user (default is None)" + ) print(" --subsample SUBSAMPLE") - print(" Downsample training set to randomly drawn N samples for training. Disabled when set to 0 (default: 0)") + print( + " Downsample training set to randomly drawn N samples " + "for training. 
Disabled when set to 0 (default: 0)" + ) print(" --features_min FEATURES_MIN") - print(" Minimum number of features to retain after feature selection (default: 500)") + print( + " Minimum number of features to retain after feature selection (default: 500)" + ) print(" --features_top_percentile FEATURES_TOP_PERCENTILE") - print(" Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection (default: 20)") + print( + " Top percentile features (among the features " + "remaining after variance filtering and data cleanup) to retain after " + "feature selection (default: 20)" + ) print(" --data_types DATA_TYPES") - print(" (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'") + print( + " (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'" + ) print(" --input_layers INPUT_LAYERS") - print(" If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple") + print( + " If model_class is set to CrossModalPred, choose " + "which data types to use as input/encoded layers. Comma-separated if " + "multiple" + ) print(" --output_layers OUTPUT_LAYERS") - print(" If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple") - print(" --outdir OUTDIR Path to the output folder to save the model outputs (default: current working directory)") + print( + " If model_class is set to CrossModalPred, choose " + "which data types to use as output/decoded layers. 
Comma-separated if " + "multiple" + ) + print( + " --outdir OUTDIR Path to the output folder to save the model " + "outputs (default: current working directory)" + ) print(" --prefix PREFIX Job prefix to use for output files (default: 'job')") print(" --log_transform {True,False}") - print(" whether to apply log-transformation to input data matrices (default: False)") + print( + " whether to apply log-transformation to input " + "data matrices (default: False)" + ) print(" --early_stop_patience EARLY_STOP_PATIENCE") - print(" How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)") + print( + " How many epochs to wait when no improvements in " + "validation loss is observed (default 10; set to -1 to disable early " + "stopping)" + ) print(" --hpo_patience HPO_PATIENCE") - print(" How many hyperparameter optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)") - print(" --val_size VAL_SIZE Proportion of training data to be used as validation split (default: 0.2)") - print(" --use_cv (Optional) If set, a 5-fold cross-validation training will be done. Otherwise, a single training on 80 percent of the dataset is done.") + print( + " How many hyperparameter optimisation iterations " + "to wait for when no improvements are observed (default is 10; set to 0 " + "to disable early stopping)" + ) + print( + " --val_size VAL_SIZE Proportion of training data to be used as " + "validation split (default: 0.2)" + ) + print( + " --use_cv (Optional) If set, a 5-fold cross-validation " + "training will be done. Otherwise, a single training on 80 percent of " + "the dataset is done." 
+ ) print(" --use_loss_weighting {True,False}") - print(" whether to apply loss-balancing using uncertainty weights method (default: True)") + print( + " whether to apply loss-balancing using uncertainty " + "weights method (default: True)" + ) print(" --evaluate_baseline_performance") - print(" whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset") - print(" --threads THREADS (Optional) How many threads to use when using CPU (default is 4)") + print( + " whether to run Random Forest + SVMs to see the " + "performance of off-the-shelf tools on the same dataset" + ) + print( + " --threads THREADS (Optional) How many threads to use when using " + "CPU (default is 4)" + ) print(" --num_workers NUM_WORKERS") - print(" (Optional) How many workers to use for model training (default is 0)") + print( + " (Optional) How many workers to use for model " + "training (default is 0)" + ) print(" --device {auto,cuda,mps,cpu}") - print(" Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)") - print(" --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available.") + print( + " Device type: 'auto' (automatic detection), 'cuda' " + "(NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)" + ) + print( + " --use_gpu (Optional) DEPRECATED: Use --device instead. If " + "set, attempts to use CUDA/GPU if available." + ) print(" --feature_importance_method {IntegratedGradients,GradientShap,Both}") - print(" Choose feature importance score method (default: IntegratedGradients)") + print( + " Choose feature importance score method " + "(default: IntegratedGradients)" + ) print(" --disable_marker_finding") - print(" (Optional) If set, marker discovery after model training is disabled.") + print( + " (Optional) If set, marker discovery after model " + "training is disabled." 
+ ) print(" --string_organism STRING_ORGANISM") print(" STRING DB organism id. (default: 9606)") print(" --string_node_name {gene_name,gene_id}") print(" Type of node name. (default: gene_name)") print(" --user_graph USER_GRAPH") - print(" Path to user-provided gene-gene interaction network file.") + print( + " Path to user-provided gene-gene interaction network file." + ) print(" Must have at least 3 columns: GeneA, GeneB, Score.") - print(" If provided, this will be used instead of STRING DB.") + print( + " If provided, this will be used instead of STRING DB." + ) print(" --safetensors") - print(" If set, the model will be saved in the SafeTensors format and the artifacts saved as JSON.") + print( + " If set, the model will be saved in the " + "SafeTensors format and the artifacts saved as JSON." + ) print(" Default is False.") print() print_test_installation() print() - print(" See the documentation for more details at https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/.") + print( + " See the documentation for more details at " + "https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/" + "getting_started/." 
+ ) def main(): @@ -318,7 +483,9 @@ def main(): Inference only: ``` - flexynesis --pretrained_model ./outputs/best_model.pth --artifacts ./outputs/artifacts.joblib --data_path_test ./data_test --outdir ./predictions --prefix run1 + flexynesis --pretrained_model ./outputs/best_model.pth --artifacts \ + ./outputs/artifacts.joblib --data_path_test ./data_test --outdir \ + ./predictions --prefix run1 ``` """ @@ -327,112 +494,312 @@ def main(): if len(sys.argv) == 1: print_help() return - if any(arg in ['-h', '--help'] for arg in sys.argv): + if any(arg in ["-h", "--help"] for arg in sys.argv): print_full_help() return # ------------- Parser (lightweight) ------------- parser = argparse.ArgumentParser( description="Flexynesis model training interface", - formatter_class=argparse.ArgumentDefaultsHelpFormatter + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") + parser.add_argument( + "-v", "--version", action="version", version=f"%(prog)s {__version__}" + ) # Existing core flags (made not-required here; enforced conditionally below) - parser.add_argument("--data_path", type=str, required=False, - help="Path to the folder with train/test data files") - parser.add_argument("--model_class", type=str, required=False, - choices=["DirectPred", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNN", "RandomForest", "SVM", "XGBoost", "RandomSurvivalForest"], - help="The kind of model class to instantiate") - parser.add_argument("--gnn_conv_type", type=str, choices=["GC", "GCN", "SAGE"], - help="If model_class is set to GNN, choose which graph convolution type to use") - parser.add_argument("--target_variables", type=str, default=None, - help="(Optional if survival variables are not set to None). 
Which variables in 'clin.csv' to use for predictions, comma-separated if multiple") - parser.add_argument("--covariates", type=str, default=None, - help="Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple") - parser.add_argument("--surv_event_var", type=str, default=None, - help="Which column in 'clin.csv' to use as event/status indicator for survival modeling") - parser.add_argument("--surv_time_var", type=str, default=None, - help="Which column in 'clin.csv' to use as time/duration indicator for survival modeling") - parser.add_argument('--config_path', type=str, default=None, - help='Optional path to an external hyperparameter configuration file in YAML format.') - parser.add_argument("--fusion_type", type=str, choices=["early", "intermediate"], default='intermediate', - help="How to fuse the omics layers") - parser.add_argument("--hpo_iter", type=int, default=100, - help="Number of iterations for hyperparameter optimisation") - parser.add_argument("--finetuning_samples", type=int, default=0, - help="Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning") - parser.add_argument("--variance_threshold", type=float, default=1, - help="Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)") - parser.add_argument("--correlation_threshold", type=float, default=0.8, - help="Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)") - parser.add_argument("--restrict_to_features", type=str, default=None, - help="Restrict the analyis to the list of features provided by the user (default is None)") - parser.add_argument("--subsample", type=int, default=0, - help="Downsample training set to randomly drawn N samples for training. 
Disabled when set to 0") - parser.add_argument("--features_min", type=int, default=500, - help="Minimum number of features to retain after feature selection") - parser.add_argument("--features_top_percentile", type=float, default=20, - help="Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection") - parser.add_argument("--data_types", type=str, required=False, - help="Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'") - parser.add_argument("--input_layers", type=str, default=None, - help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple") - parser.add_argument("--output_layers", type=str, default=None, - help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple") - parser.add_argument("--outdir", type=str, default=os.getcwd(), - help="Path to the output folder to save the model outputs") - parser.add_argument("--prefix", type=str, default='job', - help="Job prefix to use for output files") - parser.add_argument("--log_transform", type=str, choices=['True', 'False'], default='False', - help="whether to apply log-transformation to input data matrices") - parser.add_argument("--early_stop_patience", type=int, default=10, - help="How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)") - parser.add_argument("--hpo_patience", type=int, default=20, - help="How many hyperparamater optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)") - parser.add_argument("--val_size", type=float, default=0.2, - help="Proportion of training data to be used as validation split (default: 0.2)") - parser.add_argument("--use_cv", action="store_true", - help="(Optional) If set, the a 5-fold cross-validation 
training will be done. Otherwise, a single trainig on 80 percent of the dataset is done.") - parser.add_argument("--use_loss_weighting", type=str, choices=['True', 'False'], default='True', - help="whether to apply loss-balancing using uncertainty weights method") - parser.add_argument("--evaluate_baseline_performance", action="store_true", - help="whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset") - parser.add_argument("--threads", type=int, default=4, - help="(Optional) How many threads to use when using CPU (default is 4)") - parser.add_argument("--num_workers", type=int, default=0, - help="(Optional) How many workers to use for model training (default is 0)") - parser.add_argument("--use_gpu", action="store_true", - help="(Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available.") - parser.add_argument("--device", type=str, - choices=["auto", "cuda", "mps", "cpu"], default="auto", - help="Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu'") - parser.add_argument("--feature_importance_method", type=str, - choices=["IntegratedGradients", "GradientShap", "Both"], default="IntegratedGradients", - help="Choose feature importance score method") - parser.add_argument("--disable_marker_finding", action="store_true", - help="(Optional) If set, marker discovery after model training is disabled.") + parser.add_argument( + "--data_path", + type=str, + required=False, + help="Path to the folder with train/test data files", + ) + parser.add_argument( + "--model_class", + type=str, + required=False, + choices=[ + "DirectPred", + "supervised_vae", + "MultiTripletNetwork", + "CrossModalPred", + "GNN", + "RandomForest", + "SVM", + "XGBoost", + "RandomSurvivalForest", + ], + help="The kind of model class to instantiate", + ) + parser.add_argument( + "--gnn_conv_type", + type=str, + choices=["GC", "GCN", "SAGE"], + help="If model_class is set to GNN, 
choose which graph convolution type to use", + ) + parser.add_argument( + "--target_variables", + type=str, + default=None, + help="(Optional if survival variables are not set to None). Which variables " + "in 'clin.csv' to use for predictions, comma-separated if multiple", + ) + parser.add_argument( + "--covariates", + type=str, + default=None, + help="Which variables in 'clin.csv' to be used as feature covariates, " + "comma-separated if multiple", + ) + parser.add_argument( + "--surv_event_var", + type=str, + default=None, + help="Which column in 'clin.csv' to use as event/status indicator for " + "survival modeling", + ) + parser.add_argument( + "--surv_time_var", + type=str, + default=None, + help="Which column in 'clin.csv' to use as time/duration indicator for " + "survival modeling", + ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help="Optional path to an external hyperparameter configuration file in " + "YAML format.", + ) + parser.add_argument( + "--fusion_type", + type=str, + choices=["early", "intermediate"], + default="intermediate", + help="How to fuse the omics layers", + ) + parser.add_argument( + "--hpo_iter", + type=int, + default=100, + help="Number of iterations for hyperparameter optimisation", + ) + parser.add_argument( + "--finetuning_samples", + type=int, + default=0, + help="Number of samples from the test dataset to use for fine-tuning the " + "model. 
Set to 0 to disable fine-tuning", + ) + parser.add_argument( + "--variance_threshold", + type=float, + default=1, + help="Variance threshold (as percentile) to drop low variance features " + "(default is 1; set to 0 for no variance filtering)", + ) + parser.add_argument( + "--correlation_threshold", + type=float, + default=0.8, + help="Correlation threshold to drop highly redundant features (default is " + "0.8; set to 1 for no redundancy filtering)", + ) + parser.add_argument( + "--restrict_to_features", + type=str, + default=None, + help="Restrict the analyis to the list of features provided by the user (default is None)", + ) + parser.add_argument( + "--subsample", + type=int, + default=0, + help="Downsample training set to randomly drawn N samples for training. Disabled when set to 0", + ) + parser.add_argument( + "--features_min", + type=int, + default=500, + help="Minimum number of features to retain after feature selection", + ) + parser.add_argument( + "--features_top_percentile", + type=float, + default=20, + help="Top percentile features (among the features remaining after variance " + "filtering and data cleanup) to retain after feature selection", + ) + parser.add_argument( + "--data_types", + type=str, + required=False, + help="Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'", + ) + parser.add_argument( + "--input_layers", + type=str, + default=None, + help="If model_class is set to CrossModalPred, choose which data types to " + "use as input/encoded layers. Comma-separated if multiple", + ) + parser.add_argument( + "--output_layers", + type=str, + default=None, + help="If model_class is set to CrossModalPred, choose which data types to " + "use as output/decoded layers. 
Comma-separated if multiple", + ) + parser.add_argument( + "--outdir", + type=str, + default=os.getcwd(), + help="Path to the output folder to save the model outputs", + ) + parser.add_argument( + "--prefix", + type=str, + default="job", + help="Job prefix to use for output files", + ) + parser.add_argument( + "--log_transform", + type=str, + choices=["True", "False"], + default="False", + help="whether to apply log-transformation to input data matrices", + ) + parser.add_argument( + "--early_stop_patience", + type=int, + default=10, + help="How many epochs to wait when no improvements in validation loss is " + "observed (default 10; set to -1 to disable early stopping)", + ) + parser.add_argument( + "--hpo_patience", + type=int, + default=20, + help="How many hyperparamater optimisation iterations to wait for when no " + "improvements are observed (default is 10; set to 0 to disable early " + "stopping)", + ) + parser.add_argument( + "--val_size", + type=float, + default=0.2, + help="Proportion of training data to be used as validation split (default: 0.2)", + ) + parser.add_argument( + "--use_cv", + action="store_true", + help="(Optional) If set, the a 5-fold cross-validation training will be done. 
" + "Otherwise, a single trainig on 80 percent of the dataset is done.", + ) + parser.add_argument( + "--use_loss_weighting", + type=str, + choices=["True", "False"], + default="True", + help="whether to apply loss-balancing using uncertainty weights method", + ) + parser.add_argument( + "--evaluate_baseline_performance", + action="store_true", + help="whether to run Random Forest + SVMs to see the performance of " + "off-the-shelf tools on the same dataset", + ) + parser.add_argument( + "--threads", + type=int, + default=4, + help="(Optional) How many threads to use when using CPU (default is 4)", + ) + parser.add_argument( + "--num_workers", + type=int, + default=0, + help="(Optional) How many workers to use for model training (default is 0)", + ) + parser.add_argument( + "--use_gpu", + action="store_true", + help="(Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available.", + ) + parser.add_argument( + "--device", + type=str, + choices=["auto", "cuda", "mps", "cpu"], + default="auto", + help="Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu'", + ) + parser.add_argument( + "--feature_importance_method", + type=str, + choices=["IntegratedGradients", "GradientShap", "Both"], + default="IntegratedGradients", + help="Choose feature importance score method", + ) + parser.add_argument( + "--disable_marker_finding", + action="store_true", + help="(Optional) If set, marker discovery after model training is disabled.", + ) # GNN args. - parser.add_argument("--string_organism", type=int, default=9606, - help="STRING DB organism id.") - parser.add_argument("--string_node_name", type=str, choices=["gene_name", "gene_id"], default="gene_name", - help="Type of node name.") - parser.add_argument("--user_graph", type=str, default=None, - help="Path to user-provided gene-gene interaction network file. " - "Must have at least 3 columns: GeneA, GeneB, Score. 
" - "If provided, this will be used instead of STRING DB.") + parser.add_argument( + "--string_organism", + type=int, + default=9606, + help="STRING DB organism id.", + ) + parser.add_argument( + "--string_node_name", + type=str, + choices=["gene_name", "gene_id"], + default="gene_name", + help="Type of node name.", + ) + parser.add_argument( + "--user_graph", + type=str, + default=None, + help="Path to user-provided gene-gene interaction network file. Must have " + "at least 3 columns: GeneA, GeneB, Score. If provided, this will be " + "used instead of STRING DB.", + ) # safetensors args - parser.add_argument("--safetensors", action="store_true", - help="If set, use SafeTensors + JSON artifacts for save/load (training and inference). Default is False.") + parser.add_argument( + "--safetensors", + action="store_true", + help="If set, use SafeTensors + JSON artifacts for save/load (training and " + "inference). Default is False.", + ) # NEW: inference flags - parser.add_argument("--pretrained_model", type=str, default=None, - help="Path to a saved model (.pth/.safetensors) to use for inference") - parser.add_argument("--artifacts", type=str, default=None, - help="Path to artifacts .joblib or .json saved during training") - parser.add_argument("--data_path_test", type=str, default=None, - help="Folder with test-only dataset for inference") - parser.add_argument("--join_key", type=str, default="JoinKey", - help="Column name in 'clin.csv' (test metadata) used to join sample IDs") + parser.add_argument( + "--pretrained_model", + type=str, + default=None, + help="Path to a saved model (.pth/.safetensors) to use for inference", + ) + parser.add_argument( + "--artifacts", + type=str, + default=None, + help="Path to artifacts .joblib or .json saved during training", + ) + parser.add_argument( + "--data_path_test", + type=str, + default=None, + help="Folder with test-only dataset for inference", + ) + parser.add_argument( + "--join_key", + type=str, + default="JoinKey", + 
help="Column name in 'clin.csv' (test metadata) used to join sample IDs", + ) args = parser.parse_args() @@ -445,28 +812,41 @@ def main(): # Only require core training flags if NOT doing inference in_infer = bool(args.pretrained_model) if not in_infer: - missing = [k for k in ("data_path", "model_class", "data_types") if not getattr(args, k, None)] + missing = [ + k + for k in ("data_path", "model_class", "data_types") + if not getattr(args, k, None) + ] if missing: - parser.error("the following arguments are required in training mode: " + - ", ".join(f"--{m}" for m in missing)) + parser.error( + "the following arguments are required in training mode: " + + ", ".join(f"--{m}" for m in missing) + ) # ---------- Inference mode: early exit path ---------- if args.pretrained_model and args.artifacts and args.data_path_test: import torch - from .utils import get_optimal_device, create_device_from_string # quick existence checks if not os.path.exists(args.pretrained_model): - raise FileNotFoundError(f"--pretrained_model not found: {args.pretrained_model}") + raise FileNotFoundError( + f"--pretrained_model not found: {args.pretrained_model}" + ) if not os.path.exists(args.artifacts): raise FileNotFoundError(f"--artifacts not found: {args.artifacts}") # Handle device selection for inference (same logic as training) if args.use_gpu: - warnings.warn("--use_gpu is deprecated. Use --device instead.", DeprecationWarning) + warnings.warn( + "--use_gpu is deprecated. Use --device instead.", + DeprecationWarning, + ) if args.device != "auto": device_preference = args.device - print(f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}.") + print( + f"[WARN] Both --use_gpu and --device {args.device} specified. " + f"Using --device {args.device}." 
+ ) else: # Let auto-detection find the best GPU device (CUDA or MPS) device_preference = "auto" @@ -479,21 +859,26 @@ def main(): # Check model file content type from .inference import check_model_type + model_format = check_model_type(args.pretrained_model) if args.safetensors and model_format != "safetensors": - raise ValueError(f"[ERROR] The file {args.pretrained_model} is not a valid safetensors file.") + raise ValueError( + f"[ERROR] The file {args.pretrained_model} is not a valid safetensors " + f"file." + ) # Route to safetensors reconstruction or standard torch.load if model_format == "safetensors": from .inference import reconstruct_model + # Derive config path from model basename model_base = os.path.splitext(args.pretrained_model)[0] - config_path = model_base + '_config.json' + config_path = model_base + "_config.json" if not os.path.exists(config_path): # Try alongside the safetensors file with standard naming - base = args.pretrained_model.replace('.final_model.safetensors', '') - config_path = base + '.final_model_config.json' + base = args.pretrained_model.replace(".final_model.safetensors", "") + config_path = base + ".final_model_config.json" if not os.path.exists(config_path): raise FileNotFoundError( f"Cannot find config JSON for safetensors model. 
" @@ -508,12 +893,20 @@ def main(): else: # Standard .pth load — robust across PyTorch versions try: - model = torch.load(args.pretrained_model, map_location=device, weights_only=False) + model = torch.load( + args.pretrained_model, + map_location=device, + weights_only=False, + ) except TypeError: model = torch.load(args.pretrained_model, map_location=device) except Exception: try: - model = torch.load(args.pretrained_model, map_location=device, weights_only=True) + model = torch.load( + args.pretrained_model, + map_location=device, + weights_only=True, + ) except TypeError: model = torch.load(args.pretrained_model, map_location=device) @@ -524,89 +917,115 @@ def main(): # Load test data using DataImporterInference from .data import DataImporterInference - print('[INFO] Loading test data for inference...') + + print("[INFO] Loading test data for inference...") importer = DataImporterInference( test_data_path=args.data_path_test, artifacts_path=args.artifacts, - verbose=True + verbose=True, ) test_dataset = importer.import_data() # Convert to GNN dataset if needed - if args.model_class == 'GNN': + if args.model_class == "GNN": print("[INFO] Overlaying the dataset with network data from STRINGDB") - from .main import STRING - from .data import MultiOmicDatasetNW + from .data import STRING, MultiOmicDatasetNW + # Get STRING organism from artifacts - string_organism = importer.artifacts.get('string_organism', 9606) # default human - string_node_name = importer.artifacts.get('string_node_name', 'HGNC') + string_organism = importer.artifacts.get( + "string_organism", 9606 + ) # default human + string_node_name = importer.artifacts.get("string_node_name", "HGNC") # Load STRING network obj = STRING( - os.path.join(args.data_path_test, '_'.join(['processed', args.prefix])), + os.path.join(args.data_path_test, "_".join(["processed", args.prefix])), string_organism, - string_node_name + string_node_name, ) # Get modality order from artifacts for consistent feature stacking - 
modality_order = importer.artifacts.get("original_modalities", importer.artifacts.get("data_types")) - test_dataset = MultiOmicDatasetNW(test_dataset, obj.graph_df, modality_order=modality_order) + modality_order = importer.artifacts.get( + "original_modalities", importer.artifacts.get("data_types") + ) + test_dataset = MultiOmicDatasetNW( + test_dataset, obj.graph_df, modality_order=modality_order + ) train_dataset = None # No training data in inference mode # Move dataset to same device as model - if hasattr(test_dataset, 'to_device'): + if hasattr(test_dataset, "to_device"): test_dataset.to_device(device) - print(f'[INFO] Test dataset loaded: {len(test_dataset.samples)} samples') - # Continue to evaluation section (skip training) + print(f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples") - # ------------- Heavy imports only when training ------------- - from .utils import evaluate_baseline_performance, evaluate_baseline_survival_performance, get_predicted_labels, evaluate_wrapper, get_optimal_device, get_device_memory_info, create_device_from_string + # Import evaluation utilities needed in the shared evaluation section below + import pandas as pd # noqa: F401 + + from .utils import evaluate_wrapper, get_predicted_labels # noqa: F401 + + # Continue to evaluation section (skip training) + + # ------------- Heavy imports only when training or in inference with evaluation --- if not (args.pretrained_model and args.artifacts and args.data_path_test): - import flexynesis - from lightning import seed_everything - import lightning as pl - from typing import NamedTuple - import torch - import pandas as pd - from safetensors.torch import save_file + import json # noqa: F401 + import tracemalloc # noqa: F401 - # models - from .models.direct_pred import DirectPred - from .models.supervised_vae import supervised_vae - from .models.triplet_encoder import MultiTripletNetwork - from .models.crossmodal_pred import CrossModalPred - from .models.gnn_early import 
GNN + import torch # noqa: F401 + from safetensors.torch import save_file # noqa: F401 # data + utils - from .data import STRING, MultiOmicDatasetNW, DataImporter - from .main import HyperparameterTuning, FineTuner - import tracemalloc, psutil - import json + from .data import (STRING, DataImporter, # noqa: F401 + MultiOmicDatasetNW) + from .main import FineTuner, HyperparameterTuning # noqa: F401 + from .models.crossmodal_pred import CrossModalPred # noqa: F401 + # models + from .models.direct_pred import DirectPred # noqa: F401 + from .models.gnn_early import GNN # noqa: F401 + from .models.supervised_vae import supervised_vae # noqa: F401 + from .models.triplet_encoder import MultiTripletNetwork # noqa: F401 + from .utils import evaluate_baseline_performance # noqa: F401 + from .utils import (evaluate_baseline_survival_performance, + get_device_memory_info, get_optimal_device) # --------- Sanity checks on args --------- # 1. survival variables consistency if (args.surv_event_var is None) != (args.surv_time_var is None): - parser.error("Both --surv_event_var and --surv_time_var must be provided together or left as None.") + parser.error( + "Both --surv_event_var and --surv_time_var must be provided together or left as None." + ) # 2. required variables for model classes if args.model_class not in ("supervised_vae", "CrossModalPred"): if not any([args.target_variables, args.surv_event_var]): - parser.error("When selecting a model other than 'supervised_vae' or 'CrossModalPred', you must provide at least one of --target_variables, or survival variables (--surv_event_var and --surv_time_var)") + parser.error( + "When selecting a model other than 'supervised_vae' or " + "'CrossModalPred', you must provide at least one of " + "--target_variables, or survival variables (--surv_event_var and " + "--surv_time_var)" + ) # 3. 
Check for compatibility of fusion_type with CrossModalPred if args.fusion_type == "early": - if args.model_class == 'CrossModalPred': - parser.error("The 'CrossModalPred' model cannot be used with early fusion type. " - "Use --fusion_type intermediate instead.") - + if args.model_class == "CrossModalPred": + parser.error( + "The 'CrossModalPred' model cannot be used with early fusion " + "type. Use --fusion_type intermediate instead." + ) # 4. Handle device selection with MPS support # Support legacy --use_gpu flag for backward compatibility if args.use_gpu: - warnings.warn("--use_gpu is deprecated. Use --device instead.", DeprecationWarning) + warnings.warn( + "--use_gpu is deprecated. Use --device instead.", + DeprecationWarning, + ) # If --device is not explicitly set (still at default auto), let auto-detection handle it if args.device != "auto": - # If both --use_gpu and explicit --device are provided, respect --device but warn + # If both --use_gpu and explicit --device are provided, respect + # --device but warn device_preference = args.device - print(f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}.") + print( + f"[WARN] Both --use_gpu and --device {args.device} specified. " + f"Using --device {args.device}." + ) else: # Let auto-detection find the best GPU device (CUDA or MPS) device_preference = "auto" @@ -618,18 +1037,21 @@ def main(): # Print device information print(f"[INFO] Using device: {device_str}") - if device_str != 'cpu': + if device_str != "cpu": memory_info = get_device_memory_info(device_str) print(f"[INFO] Device name: {memory_info['device_name']}") - if device_str == 'cuda': + if device_str == "cuda": print(f"[INFO] Available CUDA devices: {memory_info['device_count']}") # gnn - if args.model_class == 'GNN': + if args.model_class == "GNN": if not args.gnn_conv_type: - warnings.warn("\n\n!!! When running GNN, set --gnn_conv_type (GC/GCN/SAGE). Falling back to GC !!!\n") + warnings.warn( + "\n\n!!! 
When running GNN, set --gnn_conv_type (GC/GCN/SAGE). " + "Falling back to GC !!!\n" + ) time.sleep(3) - gnn_conv_type = 'GC' + gnn_conv_type = "GC" else: gnn_conv_type = args.gnn_conv_type else: @@ -638,27 +1060,38 @@ def main(): # CrossModalPred IO layers input_layers = args.input_layers output_layers = args.output_layers - datatypes = args.data_types.strip().split(',') - if args.model_class == 'CrossModalPred': + datatypes = args.data_types.strip().split(",") + if args.model_class == "CrossModalPred": if args.input_layers: - input_layers = input_layers.strip().split(',') + input_layers = input_layers.strip().split(",") if not all(layer in datatypes for layer in input_layers): - raise ValueError(f"Input layers {input_layers} are not a valid subset of the data types: ({datatypes}).") + raise ValueError( + f"Input layers {input_layers} are not a valid subset of the " + f"data types: ({datatypes})." + ) if args.output_layers: - output_layers = output_layers.strip().split(',') + output_layers = output_layers.strip().split(",") if not all(layer in datatypes for layer in output_layers): - raise ValueError(f"Output layers {output_layers} are not a valid subset of the data types: ({datatypes}).") + raise ValueError( + f"Output layers {output_layers} are not a valid subset of the " + f"data types: ({datatypes})." 
+ ) # paths if not os.path.exists(args.data_path): - raise FileNotFoundError(f"Input --data_path doesn't exist at: {args.data_path}") + raise FileNotFoundError( + f"Input --data_path doesn't exist at: {args.data_path}" + ) if not os.path.exists(args.outdir): raise FileNotFoundError(f"Path to --outdir doesn't exist at: {args.outdir}") available_models = { "DirectPred": (DirectPred, "DirectPred"), "supervised_vae": (supervised_vae, "supervised_vae"), - "MultiTripletNetwork": (MultiTripletNetwork, "MultiTripletNetwork"), + "MultiTripletNetwork": ( + MultiTripletNetwork, + "MultiTripletNetwork", + ), "CrossModalPred": (CrossModalPred, "CrossModalPred"), "GNN": (GNN, "GNN"), "RandomForest": ("RandomForest", None), @@ -673,16 +1106,19 @@ def main(): model_class, config_name = model_info # Set concatenate to True to use early fusion, otherwise intermediate - concatenate = args.fusion_type == 'early' and args.model_class != 'GNN' + concatenate = args.fusion_type == "early" and args.model_class != "GNN" # covariates if args.covariates: - if args.model_class == 'GNN': # Covariates not yet supported for GNNs - warnings.warn("\n\n!!! Covariates are currently not supported for GNN models, they will be ignored. !!!\n") + if args.model_class == "GNN": # Covariates not yet supported for GNNs + warnings.warn( + "\n\n!!! Covariates are currently not supported for GNN models, " + "they will be ignored. 
!!!\n" + ) time.sleep(3) covariates = None else: - covariates = args.covariates.strip().split(',') + covariates = args.covariates.strip().split(",") else: covariates = None @@ -691,101 +1127,157 @@ def main(): data_types=datatypes, covariates=covariates, concatenate=concatenate, - log_transform=args.log_transform == 'True', + log_transform=args.log_transform == "True", variance_threshold=args.variance_threshold / 100, correlation_threshold=args.correlation_threshold, restrict_to_features=args.restrict_to_features, min_features=args.features_min, top_percentile=args.features_top_percentile, - processed_dir='_'.join(['processed', args.prefix]), - downsample=args.subsample + processed_dir="_".join(["processed", args.prefix]), + downsample=args.subsample, ) # import data tracemalloc.start() - process = psutil.Process(os.getpid()) - t1 = time.time() train_dataset, test_dataset = data_importer.import_data() - data_import_time = time.time() - t1 - data_import_ram = process.memory_info().rss - # classical ML baselines if args.model_class == "XGBoost": try: - from xgboost import XGBClassifier + from xgboost import XGBClassifier # noqa: F401 except Exception: raise ImportError( - "XGBoost is not available. On macOS, install the OpenMP runtime: brew install libomp" + "XGBoost is not available. 
On macOS, install the OpenMP runtime: " + "brew install libomp" ) if args.model_class in ["RandomForest", "SVM", "XGBoost"]: if args.target_variables: - var = args.target_variables.strip().split(',')[0] + var = args.target_variables.strip().split(",")[0] print(f"Training {args.model_class} on variable: {var}") metrics, predictions = evaluate_baseline_performance( - train_dataset, test_dataset, variable_name=var, methods=[args.model_class], n_folds=5, n_jobs=args.threads + train_dataset, + test_dataset, + variable_name=var, + methods=[args.model_class], + n_folds=5, + n_jobs=args.threads, + ) + metrics.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + header=True, + index=False, + ) + predictions.to_csv( + os.path.join( + args.outdir, + ".".join([args.prefix, "predicted_labels.csv"]), + ), + header=True, + index=False, ) - metrics.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'stats.csv'])), header=True, index=False) - predictions.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'predicted_labels.csv'])), header=True, index=False) print(f"{args.model_class} evaluation complete. Results saved.") sys.exit(0) else: - raise ValueError("At least one target variable is required to run RandomForest/SVM/XGBoost models. Set --target_variables") + raise ValueError( + "At least one target variable is required to run " + "RandomForest/SVM/XGBoost models. 
Set --target_variables" + ) if args.model_class == "RandomSurvivalForest": if args.surv_event_var and args.surv_time_var: - print(f"Training {args.model_class} on survival variables: {args.surv_event_var} and {args.surv_time_var}") + print( + f"Training {args.model_class} on survival variables: " + f"{args.surv_event_var} and {args.surv_time_var}" + ) metrics, predictions = evaluate_baseline_survival_performance( - train_dataset, test_dataset, args.surv_time_var, args.surv_event_var, n_folds=5, n_jobs=int(args.threads) + train_dataset, + test_dataset, + args.surv_time_var, + args.surv_event_var, + n_folds=5, + n_jobs=int(args.threads), + ) + metrics.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + header=True, + index=False, + ) + predictions.to_csv( + os.path.join( + args.outdir, + ".".join([args.prefix, "predicted_labels.csv"]), + ), + header=True, + index=False, ) - metrics.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'stats.csv'])), header=True, index=False) - predictions.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'predicted_labels.csv'])), header=True, index=False) print(f"{args.model_class} evaluation complete. Results saved.") sys.exit(0) else: - raise ValueError("Missing survival variables. Set --surv_event_var --surv_time_var") + raise ValueError( + "Missing survival variables. Set --surv_event_var --surv_time_var" + ) # GNN overlay (training mode only - inference mode already handled this) - if args.model_class == 'GNN' and train_dataset is not None: + if args.model_class == "GNN" and train_dataset is not None: # Check if user provided a custom graph if args.user_graph is not None: print(f"[INFO] Using user-provided network from: {args.user_graph}") from flexynesis.data import read_user_graph + graph_df = read_user_graph(args.user_graph) print(f"[INFO] Loaded {len(graph_df)} interactions from user graph") else: # Fallback to STRING DB (default behavior) - print("[INFO] No user graph provided. 
Using STRING DB network (default)") - obj = STRING(os.path.join(args.data_path, '_'.join(['processed', args.prefix])), - args.string_organism, args.string_node_name) + print( + "[INFO] No user graph provided. Using STRING DB network (default)" + ) + obj = STRING( + os.path.join(args.data_path, "_".join(["processed", args.prefix])), + args.string_organism, + args.string_node_name, + ) graph_df = obj.graph_df # Use data_types order from args for consistent modality ordering - modality_order = args.data_types.split(',') + modality_order = args.data_types.split(",") # Overlay the graph onto datasets - train_dataset = MultiOmicDatasetNW(train_dataset, graph_df, modality_order=modality_order) + train_dataset = MultiOmicDatasetNW( + train_dataset, graph_df, modality_order=modality_order + ) train_dataset.print_stats() - test_dataset = MultiOmicDatasetNW(test_dataset, graph_df, modality_order=modality_order) + test_dataset = MultiOmicDatasetNW( + test_dataset, graph_df, modality_order=modality_order + ) # Training only happens when train_dataset exists (not in inference mode) if train_dataset is not None: # feature logs feature_logs = data_importer.feature_logs for key in feature_logs.keys(): - feature_logs[key].to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'feature_logs', key, 'csv'])), - header=True, index=False) + feature_logs[key].to_csv( + os.path.join( + args.outdir, + ".".join([args.prefix, "feature_logs", key, "csv"]), + ), + header=True, + index=False, + ) # tuner tuner = HyperparameterTuning( dataset=train_dataset, model_class=model_class, - target_variables=args.target_variables.strip().split(',') if args.target_variables is not None else [], + target_variables=( + args.target_variables.strip().split(",") + if args.target_variables is not None + else [] + ), batch_variables=None, surv_event_var=args.surv_event_var, surv_time_var=args.surv_time_var, config_name=config_name, config_path=args.config_path, n_iter=int(args.hpo_iter), - 
use_loss_weighting=args.use_loss_weighting == 'True', + use_loss_weighting=args.use_loss_weighting == "True", val_size=args.val_size, use_cv=args.use_cv, early_stop_patience=int(args.early_stop_patience), @@ -793,22 +1285,24 @@ def main(): gnn_conv_type=gnn_conv_type, input_layers=input_layers, output_layers=output_layers, - num_workers=args.num_workers + num_workers=args.num_workers, ) # do a hyperparameter search training multiple models and get the best configuration - t1 = time.time() model, best_params = tuner.perform_tuning(hpo_patience=args.hpo_patience) - hpo_time = time.time() - t1 - hpo_system_ram = process.memory_info().rss # if fine-tuning is enabled; fine tune the model on a portion of test samples if args.finetuning_samples > 0: finetuneSampleN = args.finetuning_samples - print("[INFO] Finetuning the model on ", finetuneSampleN, "test samples") + print( + "[INFO] Finetuning the model on ", + finetuneSampleN, + "test samples", + ) # split test dataset into finetuning and holdout datasets all_indices = range(len(test_dataset)) import random as _random + finetune_indices = _random.sample(list(all_indices), finetuneSampleN) holdout_indices = list(set(all_indices) - set(finetune_indices)) finetune_dataset = test_dataset.subset(finetune_indices) @@ -827,105 +1321,211 @@ def main(): print("[INFO] Extracting sample embeddings") if train_dataset is not None: embeddings_train = model.transform(train_dataset) - embeddings_train.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_train.csv'])), header=True) + embeddings_train.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "embeddings_train.csv"])), + header=True, + ) embeddings_test = model.transform(test_dataset) - embeddings_test.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_test.csv'])), header=True) + embeddings_test.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "embeddings_test.csv"])), + header=True, + ) # evaluate predictions; (if any 
supervised learning happened) if any([args.target_variables, args.surv_event_var]): - if not args.disable_marker_finding and train_dataset is not None: # unless marker discovery is disabled + if ( + not args.disable_marker_finding and train_dataset is not None + ): # unless marker discovery is disabled # compute feature importance values - if args.feature_importance_method == 'Both': - explainers = ['IntegratedGradients', 'GradientShap'] + if args.feature_importance_method == "Both": + explainers = ["IntegratedGradients", "GradientShap"] else: explainers = [args.feature_importance_method] for explainer in explainers: - print("[INFO] Computing variable importance scores using explainer:", explainer) + print( + "[INFO] Computing variable importance scores using explainer: ", + explainer, + ) for var in model.target_variables: - model.compute_feature_importance(train_dataset, var, steps_or_samples=25, method=explainer) - import pandas as pd - df_imp = pd.concat([model.feature_importances[x] for x in model.target_variables], ignore_index=True) - df_imp['explainer'] = explainer - df_imp.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'feature_importance', explainer, 'csv'])), header=True, index=False) + model.compute_feature_importance( + train_dataset, + var, + steps_or_samples=25, + method=explainer, + ) + df_imp = pd.concat( + [ + model.feature_importances[x] + for x in model.target_variables + ], + ignore_index=True, + ) + df_imp["explainer"] = explainer + df_imp.to_csv( + os.path.join( + args.outdir, + ".".join( + [ + args.prefix, + "feature_importance", + explainer, + "csv", + ] + ), + ), + header=True, + index=False, + ) # print known/predicted labels if train_dataset is not None: - predicted_labels = pd.concat([ - get_predicted_labels(model.predict(train_dataset), train_dataset, 'train', args.model_class), - get_predicted_labels(model.predict(test_dataset), test_dataset, 'test', args.model_class) - ], ignore_index=True) + predicted_labels = pd.concat( + 
[ + get_predicted_labels( + model.predict(train_dataset), + train_dataset, + "train", + args.model_class, + ), + get_predicted_labels( + model.predict(test_dataset), + test_dataset, + "test", + args.model_class, + ), + ], + ignore_index=True, + ) else: - predicted_labels = get_predicted_labels(model.predict(test_dataset), test_dataset, 'test', args.model_class) - predicted_labels.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'predicted_labels.csv'])), header=True, index=False) + predicted_labels = get_predicted_labels( + model.predict(test_dataset), + test_dataset, + "test", + args.model_class, + ) + predicted_labels.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "predicted_labels.csv"])), + header=True, + index=False, + ) print("[INFO] Computing model evaluation metrics") metrics_df = evaluate_wrapper( - args.model_class, model.predict(test_dataset), test_dataset, + args.model_class, + model.predict(test_dataset), + test_dataset, surv_event_var=model.surv_event_var, - surv_time_var=model.surv_time_var + surv_time_var=model.surv_time_var, + ) + metrics_df.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + header=True, + index=False, ) - metrics_df.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'stats.csv'])), header=True, index=False) # for architectures with decoders; print decoded output layers - if args.model_class == 'CrossModalPred': + if args.model_class == "CrossModalPred": print("[INFO] Printing decoded output layers") # In inference mode, only decode test dataset if train_dataset is not None: output_layers_train = model.decode(train_dataset) for layer in output_layers_train.keys(): output_layers_train[layer].to_csv( - os.path.join(args.outdir, '.'.join([args.prefix, 'train_decoded', layer, 'csv'])), - header=True + os.path.join( + args.outdir, + ".".join([args.prefix, "train_decoded", layer, "csv"]), + ), + header=True, ) output_layers_test = model.decode(test_dataset) for layer in 
output_layers_test.keys(): output_layers_test[layer].to_csv( - os.path.join(args.outdir, '.'.join([args.prefix, 'test_decoded', layer, 'csv'])), - header=True + os.path.join( + args.outdir, + ".".join([args.prefix, "test_decoded", layer, "csv"]), + ), + header=True, ) # evaluate off-the-shelf methods on the main target variable if args.evaluate_baseline_performance: - print("[INFO] Computing off-the-shelf method performance on first target variable:",model.target_variables[0]) + print( + "[INFO] Computing off-the-shelf method performance on first target variable:", + model.target_variables[0], + ) var = model.target_variables[0] metrics = pd.DataFrame() # in the case when GNNEarly was used, the we use the initial multiomicdataset for train/test # because GNNEarly requires a modified dataset structure to fit the networks (temporary solution) - if args.model_class == 'GNN': - train = getattr(train_dataset, 'multiomic_dataset', train_dataset) - test = getattr(test_dataset, 'multiomic_dataset', test_dataset) + if args.model_class == "GNN": + train = getattr(train_dataset, "multiomic_dataset", train_dataset) + test = getattr(test_dataset, "multiomic_dataset", test_dataset) else: train = train_dataset test = test_dataset if var != model.surv_event_var: - metrics, predictions = evaluate_baseline_performance(train, test, - variable_name = var, - methods = ['RandomForest', 'SVM', 'XGBoost'], - n_folds = 5, - n_jobs = int(args.threads)) - predictions.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'baseline.predicted_labels.csv'])), header=True, index=False) + metrics, predictions = evaluate_baseline_performance( + train, + test, + variable_name=var, + methods=["RandomForest", "SVM", "XGBoost"], + n_folds=5, + n_jobs=int(args.threads), + ) + predictions.to_csv( + os.path.join( + args.outdir, + ".".join([args.prefix, "baseline.predicted_labels.csv"]), + ), + header=True, + index=False, + ) if model.surv_event_var and model.surv_time_var: - print("[INFO] Computing 
off-the-shelf method performance on survival variable:",model.surv_time_var) - metrics_baseline_survival = evaluate_baseline_survival_performance(train, test, - model.surv_time_var, - model.surv_event_var, - n_folds = 5, - n_jobs = int(args.threads)) - metrics = pd.concat([metrics, metrics_baseline_survival], axis = 0, ignore_index = True) + print( + "[INFO] Computing off-the-shelf method performance on survival " + "variable:", + model.surv_time_var, + ) + metrics_baseline_survival = evaluate_baseline_survival_performance( + train, + test, + model.surv_time_var, + model.surv_event_var, + n_folds=5, + n_jobs=int(args.threads), + ) + metrics = pd.concat( + [metrics, metrics_baseline_survival], axis=0, ignore_index=True + ) if not metrics.empty: - metrics.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'baseline.stats.csv'])), header=True, index=False) + metrics.to_csv( + os.path.join( + args.outdir, ".".join([args.prefix, "baseline.stats.csv"]) + ), + header=True, + index=False, + ) # save the trained model in file (skip in inference mode) if not in_infer: if not args.safetensors: - torch.save(model, os.path.join(args.outdir, '.'.join([args.prefix, 'final_model.pth']))) + torch.save( + model, + os.path.join(args.outdir, ".".join([args.prefix, "final_model.pth"])), + ) else: - save_file(model.state_dict(), os.path.join(args.outdir, '.'.join([args.prefix, 'final_model.safetensors']))) + save_file( + model.state_dict(), + os.path.join( + args.outdir, + ".".join([args.prefix, "final_model.safetensors"]), + ), + ) # save model config as JSON config = { "model_class": model.__class__.__name__, @@ -933,50 +1533,89 @@ def main(): } # Common attributes to save common_attrs = [ - 'input_dims', 'layers', 'input_layers', 'output_layers', - 'device_type', 'target_variables', - 'surv_event_var', 'surv_time_var', - 'config', 'current_epoch', 'num_layers' + "input_dims", + "layers", + "input_layers", + "output_layers", + "device_type", + "target_variables", + 
"surv_event_var", + "surv_time_var", + "config", + "current_epoch", + "num_layers", ] for attr in common_attrs: if hasattr(model, attr): config[attr] = getattr(model, attr) - if hasattr(model, 'layers'): - config['num_layers'] = len(model.layers) + if hasattr(model, "layers"): + config["num_layers"] = len(model.layers) - if hasattr(model, 'config'): + if hasattr(model, "config"): model_specific_config = model.config config.update(model_specific_config) - with open(os.path.join(args.outdir, '.'.join([args.prefix, 'final_model_config.json'])), 'w') as f: + with open( + os.path.join( + args.outdir, ".".join([args.prefix, "final_model_config.json"]) + ), + "w", + ) as f: json.dump(config, f, indent=2, default=str) # --- write inference artifacts joblib (auto-generated after training) --- if train_dataset is not None: # Only save artifacts in training mode try: import joblib + artifacts = { - 'schema_version': 1, - 'data_types': list(data_importer.train_features.keys()) if hasattr(data_importer, 'train_features') else args.data_types.split(','), # Use actual data structure keys (e.g. 
['all'] for early fusion) - 'original_modalities': args.data_types.split(','), # Original modalities from CLI before concatenation - 'target_variables': args.target_variables.split(',') if args.target_variables else [], - 'feature_lists': data_importer.train_features if hasattr(data_importer, 'train_features') else {}, - 'transforms': data_importer.scalers if hasattr(data_importer, 'scalers') else {}, - 'label_encoders': data_importer.label_encoders if hasattr(data_importer, 'label_encoders') else {}, - 'covariate_vars': covariates if covariates is not None else [], # Store covariate variable names - 'join_key': args.join_key, - 'string_organism': args.string_organism, - 'string_node_name': args.string_node_name, + "schema_version": 1, + "data_types": ( + list(data_importer.train_features.keys()) + if hasattr(data_importer, "train_features") + else args.data_types.split(",") + ), # Use actual data structure keys (e.g. ['all'] for early + # fusion + "original_modalities": args.data_types.split( + "," + ), # Original modalities from CLI before concatenation + "target_variables": ( + args.target_variables.split(",") if args.target_variables else [] + ), + "feature_lists": ( + data_importer.train_features + if hasattr(data_importer, "train_features") + else {} + ), + "transforms": ( + data_importer.scalers if hasattr(data_importer, "scalers") else {} + ), + "label_encoders": ( + data_importer.label_encoders + if hasattr(data_importer, "label_encoders") + else {} + ), + "covariate_vars": ( + covariates if covariates is not None else [] + ), # Store covariate variable names + "join_key": args.join_key, + "string_organism": args.string_organism, + "string_node_name": args.string_node_name, } if not args.safetensors: - joblib_path = os.path.join(args.outdir, '.'.join([args.prefix, 'artifacts.joblib'])) + joblib_path = os.path.join( + args.outdir, ".".join([args.prefix, "artifacts.joblib"]) + ) joblib.dump(artifacts, joblib_path) - print(f'[INFO] Wrote inference artifacts 
to {joblib_path}') + print(f"[INFO] Wrote inference artifacts to {joblib_path}") elif args.safetensors: - import numpy as np - from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler + import numpy as np # noqa: F401 + from sklearn.preprocessing import LabelEncoder # noqa: F401 + from sklearn.preprocessing import (OrdinalEncoder, + StandardScaler) + json_ready = { "schema_version": artifacts["schema_version"], "data_types": artifacts["data_types"], @@ -986,7 +1625,10 @@ def main(): "join_key": artifacts["join_key"], "string_organism": artifacts["string_organism"], "string_node_name": artifacts["string_node_name"], - "feature_lists": {modality: list(features) for modality, features in artifacts["feature_lists"].items()}, + "feature_lists": { + modality: list(features) + for modality, features in artifacts["feature_lists"].items() + }, "transforms": {}, "label_encoders": {}, } @@ -997,16 +1639,34 @@ def main(): json_ready["transforms"][modality] = None continue if not isinstance(scaler, StandardScaler): - raise ValueError(f"Unsupported scaler type for modality '{modality}': {type(scaler).__name__}.") - scaler_dict = {"type": "StandardScaler", "with_mean": scaler.with_mean, "with_std": scaler.with_std} - if hasattr(scaler, "mean_"): scaler_dict["mean"] = scaler.mean_.tolist() - if hasattr(scaler, "scale_"): scaler_dict["scale"] = scaler.scale_.tolist() - if hasattr(scaler, "var_"): scaler_dict["var"] = scaler.var_.tolist() - if hasattr(scaler, "n_features_in_"): scaler_dict["n_features_in"] = int(scaler.n_features_in_) - if hasattr(scaler, "feature_names_in_"): scaler_dict["feature_names_in"] = scaler.feature_names_in_.tolist() + raise ValueError( + f"Unsupported scaler type for modality '{modality}': " + f"{type(scaler).__name__}." 
+ ) + scaler_dict = { + "type": "StandardScaler", + "with_mean": scaler.with_mean, + "with_std": scaler.with_std, + } + if hasattr(scaler, "mean_"): + scaler_dict["mean"] = scaler.mean_.tolist() + if hasattr(scaler, "scale_"): + scaler_dict["scale"] = scaler.scale_.tolist() + if hasattr(scaler, "var_"): + scaler_dict["var"] = scaler.var_.tolist() + if hasattr(scaler, "n_features_in_"): + scaler_dict["n_features_in"] = int(scaler.n_features_in_) + if hasattr(scaler, "feature_names_in_"): + scaler_dict["feature_names_in"] = ( + scaler.feature_names_in_.tolist() + ) if hasattr(scaler, "n_samples_seen_"): n_samples = scaler.n_samples_seen_ - scaler_dict["n_samples_seen"] = n_samples.tolist() if isinstance(n_samples, np.ndarray) else int(n_samples) + scaler_dict["n_samples_seen"] = ( + n_samples.tolist() + if isinstance(n_samples, np.ndarray) + else int(n_samples) + ) json_ready["transforms"][modality] = scaler_dict # Convert LabelEncoder/OrdinalEncoder objects @@ -1015,31 +1675,55 @@ def main(): json_ready["label_encoders"][variable] = None continue if isinstance(encoder, LabelEncoder): - encoder_dict = {"type": "LabelEncoder", "classes": encoder.classes_.tolist()} + encoder_dict = { + "type": "LabelEncoder", + "classes": encoder.classes_.tolist(), + } elif isinstance(encoder, OrdinalEncoder): encoder_dict = { "type": "OrdinalEncoder", - "categories": [cat.tolist() for cat in encoder.categories_], + "categories": [ + cat.tolist() for cat in encoder.categories_ + ], "handle_unknown": encoder.handle_unknown, "unknown_value": encoder.unknown_value, } if hasattr(encoder, "encoded_missing_value"): val = encoder.encoded_missing_value - encoder_dict["encoded_missing_value"] = "__NaN__" if isinstance(val, float) and np.isnan(val) else val - if hasattr(encoder, "n_features_in_"): encoder_dict["n_features_in"] = int(encoder.n_features_in_) - if hasattr(encoder, "feature_names_in_"): encoder_dict["feature_names_in"] = encoder.feature_names_in_.tolist() + 
encoder_dict["encoded_missing_value"] = ( + "__NaN__" + if isinstance(val, float) and np.isnan(val) + else val + ) + if hasattr(encoder, "n_features_in_"): + encoder_dict["n_features_in"] = int(encoder.n_features_in_) + if hasattr(encoder, "feature_names_in_"): + encoder_dict["feature_names_in"] = ( + encoder.feature_names_in_.tolist() + ) if hasattr(encoder, "_missing_indices"): missing_indices = encoder._missing_indices - encoder_dict["_missing_indices"] = {str(k): v for k, v in missing_indices.items()} if isinstance(missing_indices, dict) else missing_indices - if hasattr(encoder, "_infrequent_enabled"): encoder_dict["_infrequent_enabled"] = encoder._infrequent_enabled + encoder_dict["_missing_indices"] = ( + {str(k): v for k, v in missing_indices.items()} + if isinstance(missing_indices, dict) + else missing_indices + ) + if hasattr(encoder, "_infrequent_enabled"): + encoder_dict["_infrequent_enabled"] = ( + encoder._infrequent_enabled + ) else: - raise ValueError(f"Unknown encoder type: {type(encoder).__name__}") + raise ValueError( + f"Unknown encoder type: {type(encoder).__name__}" + ) json_ready["label_encoders"][variable] = encoder_dict - json_path = os.path.join(args.outdir, '.'.join([args.prefix, 'artifacts.json'])) + json_path = os.path.join( + args.outdir, ".".join([args.prefix, "artifacts.json"]) + ) with open(json_path, "w") as f: json.dump(json_ready, f, indent=2) - print(f'[INFO] Wrote inference artifacts to {json_path}') + print(f"[INFO] Wrote inference artifacts to {json_path}") except Exception as e: - print(f'[WARN] Could not write inference artifacts: {e}') + print(f"[WARN] Could not write inference artifacts: {e}") diff --git a/flexynesis/config.py b/flexynesis/config.py index 817347bd..fda0444a 100644 --- a/flexynesis/config.py +++ b/flexynesis/config.py @@ -1,44 +1,52 @@ # config.py -from skopt.space import Integer, Categorical, Real +from skopt.space import Categorical, Integer, Real epochs = [500] search_spaces = { - 'DirectPred': [ - 
Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Integer(8, 32, name='supervisor_hidden_dim'), - Categorical(epochs, name='epochs') - ], - 'supervised_vae': [ - Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Integer(8, 32, name='supervisor_hidden_dim'), - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Categorical(epochs, name='epochs') + "DirectPred": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Integer(8, 32, name="supervisor_hidden_dim"), + Categorical(epochs, name="epochs"), ], - 'CrossModalPred': [ - Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Integer(8, 32, name='supervisor_hidden_dim'), - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Categorical(epochs, name='epochs') + "supervised_vae": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Integer(8, 32, name="supervisor_hidden_dim"), + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Categorical(epochs, name="epochs"), ], - 'MultiTripletNetwork': [ - Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Integer(8, 32, name='supervisor_hidden_dim'), - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Categorical(epochs, name='epochs') + "CrossModalPred": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Integer(8, 32, name="supervisor_hidden_dim"), + Real(0.0001, 0.01, 
prior="log-uniform", name="lr"), + Categorical(epochs, name="epochs"), + ], + "MultiTripletNetwork": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Integer(8, 32, name="supervisor_hidden_dim"), + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Categorical(epochs, name="epochs"), + ], + "GNN": [ + Integer(16, 128, name="latent_dim"), + Integer(4, 32, name="node_embedding_dim"), # node embedding dimensions + Integer(1, 4, name="num_convs"), # number of convolutional layers + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Integer(8, 32, name="supervisor_hidden_dim"), + Categorical(epochs, name="epochs"), + Categorical(["relu"], name="activation"), ], - 'GNN': [ - Integer(16, 128, name='latent_dim'), - Integer(4, 32, name='node_embedding_dim'), # node embedding dimensions - Integer(1, 4, name='num_convs'), # number of convolutional layers - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Integer(8, 32, name='supervisor_hidden_dim'), - Categorical(epochs, name='epochs'), - Categorical(['relu'], name="activation") - ] } diff --git a/flexynesis/data.py b/flexynesis/data.py index 52545a6b..e6f12831 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -1,26 +1,21 @@ -from torch.utils.data import Dataset, DataLoader -from torch_geometric.data import download_url, extract_gz -from torch_geometric.data import Dataset as PYGDataset +import os +import tempfile +from functools import reduce +from itertools import chain +from pathlib import Path import numpy as np import pandas as pd -from functools import reduce import torch -import os -import shutil - -from pathlib import Path from filelock import FileLock from platformdirs import user_cache_dir +from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder, StandardScaler +from torch.utils.data import Dataset +from torch_geometric.data import Dataset as PYGDataset +from torch_geometric.data import 
download_url, extract_gz -from tqdm import tqdm - - -from sklearn.preprocessing import OrdinalEncoder, StandardScaler, MinMaxScaler, PowerTransformer from .feature_selection import filter_by_laplacian -from .utils import get_variable_types, create_covariate_matrix - -from itertools import chain +from .utils import create_covariate_matrix, get_variable_types # convert_to_labels: if true, given a numeric list, convert to binary labels by median value @@ -41,7 +36,8 @@ class DataImporter: variance_threshold (float): The variance threshold for removing low-variance features. na_threshold (float): The threshold for removing features with too many NA values. string_organism (int): STRING organism (species) id (default: 9606 (human)). - string_node_name (str): The type of node names used in the graph. Available options: "gene_name", "gene_id" (default: "gene_name"). + string_node_name (str): The type of node names used in the graph. + Available options: "gene_name", "gene_id" (default: "gene_name"). Methods: import_data(): The primary method to orchestrate the data import and preprocessing workflow. It follows these steps: @@ -69,9 +65,10 @@ class DataImporter: Prepares the data for model input by cleaning, filtering, and selecting features and samples. select_features(dat): - Implements an unsupervised feature selection by ranking features by the Laplacian score, keeping the features at - the top percentile range and removing highly redundant features (optional) based on a correlation threshold, - while keeping a minimum number of top features as requested by the user. + Implements an unsupervised feature selection by ranking features by the Laplacian + score, keeping the features at the top percentile range and removing highly + redundant features (optional) based on a correlation threshold, while keeping a + minimum number of top features as requested by the user. 
harmonize(dat1, dat2): Aligns the feature sets of two datasets (e.g., training and testing) to have the same features. @@ -92,8 +89,22 @@ class DataImporter: Encodes categorical labels in the annotation dataframe. """ - def __init__(self, path, data_types, covariates = None, processed_dir="processed", log_transform = False, concatenate = False, restrict_to_features = None, min_features=None, - top_percentile=20, correlation_threshold = 0.9, variance_threshold=0.01, na_threshold=0.1, downsample=0): + def __init__( + self, + path, + data_types, + covariates=None, + processed_dir="processed", + log_transform=False, + concatenate=False, + restrict_to_features=None, + min_features=None, + top_percentile=20, + correlation_threshold=0.9, + variance_threshold=0.01, + na_threshold=0.1, + downsample=0, + ): self.path = path self.data_types = data_types self.processed_dir = os.path.join(self.path, processed_dir) @@ -105,7 +116,7 @@ def __init__(self, path, data_types, covariates = None, processed_dir="processed self.na_threshold = na_threshold self.log_transform = log_transform # Initialize a dictionary to store the label encoders - self.encoders = {} # used if labels are categorical + self.encoders = {} # used if labels are categorical # initialize data scalers self.scalers = None # initialize data transformers @@ -124,8 +135,12 @@ def __init__(self, path, data_types, covariates = None, processed_dir="processed self.feature_logs = {} # NEW: Storage for inference mode artifacts - self.train_features = {} # Stores final feature names per modality after selection - self.label_encoders = {} # Stores fitted label encoders for categorical variables + self.train_features = ( + {} + ) # Stores final feature names per modality after selection + self.label_encoders = ( + {} + ) # Stores fitted label encoders for categorical variables # Note: self.scalers already exists, so we'll use that! 
def get_user_features(self): @@ -136,9 +151,11 @@ def get_user_features(self): if not os.path.isfile(self.restrict_to_features): raise FileNotFoundError(f"File not found: {self.restrict_to_features}") try: - with open(self.restrict_to_features, 'r') as fp: + with open(self.restrict_to_features, "r") as fp: # Read and process the file - feature_list = [x.strip() for x in fp.read().splitlines() if x.strip()] + feature_list = [ + x.strip() for x in fp.read().splitlines() if x.strip() + ] # Ensure uniqueness and assign self.restrict_to_features = np.unique(feature_list) except Exception as e: @@ -148,8 +165,8 @@ def get_user_features(self): def import_data(self): print("\n[INFO] ================= Importing Data =================") - training_path = os.path.join(self.path, 'train') - testing_path = os.path.join(self.path, 'test') + training_path = os.path.join(self.path, "train") + testing_path = os.path.join(self.path, "test") self.validate_data_folders(training_path, testing_path) @@ -158,7 +175,11 @@ def import_data(self): test_dat = self.read_data(testing_path) if self.downsample > 0: - print("[INFO] Randomly drawing",self.downsample,"samples for training") + print( + "[INFO] Randomly drawing", + self.downsample, + "samples for training", + ) train_dat = self.subsample(train_dat, self.downsample) if self.restrict_to_features is not None: @@ -169,8 +190,12 @@ def import_data(self): self.validate_input_data(train_dat, test_dat) # cleanup uninformative features/samples, subset annotation data, do feature selection on training data - train_dat, train_ann, train_samples, train_features = self.process_data(train_dat, split = 'train') - test_dat, test_ann, test_samples, test_features = self.process_data(test_dat, split = 'test') + train_dat, train_ann, train_samples, train_features = self.process_data( + train_dat, split="train" + ) + test_dat, test_ann, test_samples, test_features = self.process_data( + test_dat, split="test" + ) # harmonize feature sets in train/test 
train_dat, test_dat = self.harmonize(train_dat, test_dat) @@ -188,9 +213,16 @@ def import_data(self): # if covariates are defined, create a covariate matrix and add to the dictionary of data matrices if self.covariates: - print("[INFO] Attempting to create a covariate matrix for the covariates:",self.covariates) - train_dat['covariates'] = create_covariate_matrix(self.covariates, get_variable_types(train_ann), train_ann) - test_dat['covariates'] = create_covariate_matrix(self.covariates, get_variable_types(test_ann), test_ann) + print( + "[INFO] Attempting to create a covariate matrix for the covariates:", + self.covariates, + ) + train_dat["covariates"] = create_covariate_matrix( + self.covariates, get_variable_types(train_ann), train_ann + ) + test_dat["covariates"] = create_covariate_matrix( + self.covariates, get_variable_types(test_ann), test_ann + ) # harmonize again to match the covariate features train_dat, test_dat = self.harmonize(train_dat, test_dat) @@ -202,27 +234,51 @@ def import_data(self): if self.concatenate: # Use data_types order for consistent concatenation modality_order = self.data_types - training_dataset.dat = {'all': torch.cat([training_dataset.dat[x] for x in modality_order], dim = 1)} - training_dataset.features = {'all': list(chain(*[training_dataset.features[x] for x in modality_order]))} + training_dataset.dat = { + "all": torch.cat( + [training_dataset.dat[x] for x in modality_order], dim=1 + ) + } + training_dataset.features = { + "all": list( + chain(*[training_dataset.features[x] for x in modality_order]) + ) + } - testing_dataset.dat = {'all': torch.cat([testing_dataset.dat[x] for x in modality_order], dim = 1)} - testing_dataset.features = {'all': list(chain(*[testing_dataset.features[x] for x in modality_order]))} + testing_dataset.dat = { + "all": torch.cat( + [testing_dataset.dat[x] for x in modality_order], dim=1 + ) + } + testing_dataset.features = { + "all": list( + chain(*[testing_dataset.features[x] for x in 
modality_order]) + ) + } # Save final feature lists AFTER concatenation (for inference mode) self.train_features = training_dataset.features.copy() - - print("[INFO] Training Data Stats: ", training_dataset.get_dataset_stats()) + print( + "[INFO] Training Data Stats: ", + training_dataset.get_dataset_stats(), + ) print("[INFO] Test Data Stats: ", testing_dataset.get_dataset_stats()) print("[INFO] Merging Feature Logs...") logs = self.feature_logs - if 'select_features' in logs: - self.feature_logs = {x: pd.merge(logs['cleanup'][x], - logs['select_features'][x], - on = 'feature', how = 'outer', - suffixes=['_cleanup', '_laplacian']) for x in self.data_types} + if "select_features" in logs: + self.feature_logs = { + x: pd.merge( + logs["cleanup"][x], + logs["select_features"][x], + on="feature", + how="outer", + suffixes=["_cleanup", "_laplacian"], + ) + for x in self.data_types + } else: # Feature selection was skipped (top_percentile=0), so just use cleanup logs - self.feature_logs = logs['cleanup'] + self.feature_logs = logs["cleanup"] print("[INFO] Data import successful.") return training_dataset, testing_dataset @@ -232,19 +288,23 @@ def validate_data_folders(self, training_path, testing_path): training_files = set(os.listdir(training_path)) testing_files = set(os.listdir(testing_path)) - required_files = {'clin.csv'} | {f"{dt}.csv" for dt in self.data_types} + required_files = {"clin.csv"} | {f"{dt}.csv" for dt in self.data_types} if not required_files.issubset(training_files): missing_files = required_files - training_files - raise ValueError(f"Missing files in training folder: {', '.join(missing_files)}") + raise ValueError( + f"Missing files in training folder: {', '.join(missing_files)}" + ) if not required_files.issubset(testing_files): missing_files = required_files - testing_files - raise ValueError(f"Missing files in testing folder: {', '.join(missing_files)}") + raise ValueError( + f"Missing files in testing folder: {', '.join(missing_files)}" + ) def 
read_data(self, folder_path): data = {} - required_files = {'clin.csv'} | {f"{dt}.csv" for dt in self.data_types} + required_files = {"clin.csv"} | {f"{dt}.csv" for dt in self.data_types} print("\n[INFO] ----------------- Reading Data ----------------- ") for file in required_files: file_path = os.path.join(folder_path, file) @@ -255,36 +315,41 @@ def read_data(self, folder_path): # randomly draw N samples; return subset of dat (output of read_data) def subsample(self, dat, N): - clin = dat['clin'].sample(N) + clin = dat["clin"].sample(N) dat_sub = {x: dat[x][clin.index] for x in self.data_types} - dat_sub['clin'] = clin + dat_sub["clin"] = clin return dat_sub - def filter_by_features(self, dat, features): """ If the user has provided list of features to restrict the analysis to, subset train/test data to only include those features """ dat_filtered = { - key: df if key == "clin" else df.loc[df.index.intersection(features)] + key: (df if key == "clin" else df.loc[df.index.intersection(features)]) for key, df in dat.items() } - print("[INFO] The initial features are filtered to include user-provided features only") + print( + "[INFO] The initial features are filtered to include user-provided features only" + ) for key, df in dat_filtered.items(): remaining_features = len(df.index) - print(f"In layer '{key}', {remaining_features} features are remaining after filtering.") + print( + f"In layer '{key}', {remaining_features} features are remaining after filtering." 
+ ) return dat_filtered - def process_data(self, data, split = 'train'): - print(f"\n[INFO] ----------------- Processing Data ({split}) ----------------- ") + def process_data(self, data, split="train"): + print( + f"\n[INFO] ----------------- Processing Data ({split}) ----------------- " + ) # remove uninformative features and samples with no information (from data matrices) dat = self.cleanup_data({x: data[x] for x in self.data_types}) - ann = data['clin'] + ann = data["clin"] dat, ann, samples = self.get_labels(dat, ann) # do feature selection: only applied to training data - if split == 'train': + if split == "train": if self.top_percentile: dat = self.select_features(dat) features = {x: dat[x].index for x in dat.keys()} @@ -295,10 +360,10 @@ def cleanup_data(self, df_dict): cleaned_dfs = {} sample_masks = [] - feature_logs = {} # keep track of feature variation/NA value scores + feature_logs = {} # keep track of feature variation/NA value scores # First pass: remove near-zero-variation features and create masks for informative samples for key, df in df_dict.items(): - print("\n[INFO] working on layer: ",key) + print("\n[INFO] working on layer: ", key) original_features_count = df.shape[0] # Compute variances and NA percentages for each feature in the DataFrame @@ -306,13 +371,29 @@ def cleanup_data(self, df_dict): na_percentages = df.isna().mean(axis=1) # Combine variances and NA percentages into a single DataFrame for logging - log_df = pd.DataFrame({ 'feature': df.index, 'na_percent': na_percentages, 'variance': feature_variances, 'selected': False}) + log_df = pd.DataFrame( + { + "feature": df.index, + "na_percent": na_percentages, + "variance": feature_variances, + "selected": False, + } + ) # Filter based on both variance and NA percentage thresholds # Identify features that meet both criteria - df = df.loc[(feature_variances >= feature_variances.quantile(self.variance_threshold)) & (na_percentages < self.na_threshold)] + df = df.loc[ + ( + 
feature_variances + >= feature_variances.quantile(self.variance_threshold) + ) + & (na_percentages < self.na_threshold) + ] # set selected features to True - log_df['selected'] = (log_df['variance'] >= feature_variances.quantile(self.variance_threshold)) & (log_df['na_percent'] < self.na_threshold) + log_df["selected"] = ( + log_df["variance"] + >= feature_variances.quantile(self.variance_threshold) + ) & (log_df["na_percent"] < self.na_threshold) feature_logs[key] = log_df # Step 3: Fill NA values with the median of the feature @@ -320,7 +401,12 @@ def cleanup_data(self, df_dict): if np.sum(df.isna().sum()) > 0: missing_rows = df.isna().any(axis=1) - print("[INFO] Imputing NA values to median of features, affected # of cells in the matrix", np.sum(df.isna().sum()), " # of rows:",sum(missing_rows)) + print( + "[INFO] Imputing NA values to median of features, affected # of cells in the matrix", + np.sum(df.isna().sum()), + " # of rows:", + sum(missing_rows), + ) # Calculate medians for each 'column' (originally rows) and fill NAs # Note: After transposition, operations are more efficient @@ -329,16 +415,20 @@ def cleanup_data(self, df_dict): df_T.fillna(medians_T, inplace=True) df = df_T.T - print("[INFO] Number of NA values: ",np.sum(df.isna().sum())) + print("[INFO] Number of NA values: ", np.sum(df.isna().sum())) removed_features_count = original_features_count - df.shape[0] - print(f"[INFO] DataFrame {key} - Removed {removed_features_count} features.") + print( + f"[INFO] DataFrame {key} - Removed {removed_features_count} features." 
+ ) # Step 2: Create masks for informative samples # Compute standard deviation of samples (along columns) sample_stdevs = df.std(axis=0) # Create mask for samples that do not have std dev of 0 or NaN - mask = np.logical_and(sample_stdevs != 0, np.logical_not(np.isnan(sample_stdevs))) + mask = np.logical_and( + sample_stdevs != 0, np.logical_not(np.isnan(sample_stdevs)) + ) sample_masks.append(mask) cleaned_dfs[key] = df @@ -351,15 +441,26 @@ def cleanup_data(self, df_dict): original_samples_count = cleaned_dfs[key].shape[1] cleaned_dfs[key] = cleaned_dfs[key].loc[:, common_mask] removed_samples_count = original_samples_count - cleaned_dfs[key].shape[1] - print(f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%).") + removed_samples_pct = ( + removed_samples_count / original_samples_count * 100 + ) + print( + f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples " + f"({removed_samples_pct:.2f}%)." + ) # update feature logs from this process - self.feature_logs['cleanup'] = feature_logs + self.feature_logs["cleanup"] = feature_logs return cleaned_dfs def get_labels(self, dat, ann): # subset samples and reorder annotations for the samples - samples = list(reduce(set.intersection, [set(item) for item in [dat[x].columns for x in dat.keys()]])) + samples = list( + reduce( + set.intersection, + [set(item) for item in [dat[x].columns for x in dat.keys()]], + ) + ) samples = list(set(ann.index).intersection(samples)) dat = {x: dat[x][samples] for x in dat.keys()} ann = ann.loc[samples] @@ -367,18 +468,28 @@ def get_labels(self, dat, ann): # unsupervised feature selection using laplacian score and correlation filters (optional) def select_features(self, dat): - counts = {x: max(int(dat[x].shape[0] * self.top_percentile / 100), self.min_features) for x in dat.keys()} + counts = { + x: max( + int(dat[x].shape[0] * self.top_percentile / 100), + self.min_features, + ) + for x in 
dat.keys() + } dat_filtered = {} - feature_logs = {} # feature log for each layer + feature_logs = {} # feature log for each layer for layer in dat.keys(): # filter features in the layer and keep a log of filtering process; notice we provide a transposed matrix - X_filt, log_df = filter_by_laplacian(X = dat[layer].T, layer = layer, - topN=counts[layer], correlation_threshold = self.correlation_threshold) - dat_filtered[layer] = X_filt.T # transpose after laplacian filtering again + X_filt, log_df = filter_by_laplacian( + X=dat[layer].T, + layer=layer, + topN=counts[layer], + correlation_threshold=self.correlation_threshold, + ) + dat_filtered[layer] = X_filt.T # transpose after laplacian filtering again # Features will be stored after concatenation feature_logs[layer] = log_df # update main feature logs with events from this function - self.feature_logs['select_features'] = feature_logs + self.feature_logs["select_features"] = feature_logs return dat_filtered def harmonize(self, dat1, dat2): @@ -386,7 +497,9 @@ def harmonize(self, dat1, dat2): # common data layers common_layers = dat1.keys() & dat2.keys() # Get common features - common_features = {x: dat1[x].index.intersection(dat2[x].index) for x in common_layers} + common_features = { + x: dat1[x].index.intersection(dat2[x].index) for x in common_layers + } # Subset both datasets to only include common features dat1 = {x: dat1[x].loc[common_features[x]] for x in common_layers} dat2 = {x: dat2[x].loc[common_features[x]] for x in common_layers} @@ -411,10 +524,14 @@ def normalize_data(self, data, scaler_type="standard", fit=True): else: raise ValueError("Invalid scaler_type. 
Choose 'standard' or 'min_max'.") - normalized_data = {x: pd.DataFrame(self.scalers[x].transform(data[x].T), - index=data[x].columns, - columns=data[x].index).T - for x in data.keys()} + normalized_data = { + x: pd.DataFrame( + self.scalers[x].transform(data[x].T), + index=data[x].columns, + columns=data[x].index, + ).T + for x in data.keys() + } return normalized_data def get_torch_dataset(self, dat, ann, samples): @@ -425,19 +542,31 @@ def get_torch_dataset(self, dat, ann, samples): ann, variable_types, label_mappings = self.encode_labels(ann) # Convert DataFrame to tensor with MPS-compatible dtypes - ann = {col: torch.from_numpy(ann[col].values.copy()).float() if ann[col].dtype in ['float64', 'float32'] - else torch.from_numpy(ann[col].values.copy()) for col in ann.columns} - return MultiOmicDataset(dat, ann, variable_types, features, samples, label_mappings) + ann = { + col: ( + torch.from_numpy(ann[col].values.copy()).float() + if ann[col].dtype in ["float64", "float32"] + else torch.from_numpy(ann[col].values.copy()) + ) + for col in ann.columns + } + return MultiOmicDataset( + dat, ann, variable_types, features, samples, label_mappings + ) def encode_labels(self, df): label_mappings = {} + def encode_column(series): - nonlocal label_mappings # Declare as nonlocal so that we can modify it # Fill NA values with 'missing' # series = series.fillna('missing') if series.name not in self.encoders: - self.encoders[series.name] = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1) - encoded_series = self.encoders[series.name].fit_transform(series.to_frame()) + self.encoders[series.name] = OrdinalEncoder( + handle_unknown="use_encoded_value", unknown_value=-1 + ) + encoded_series = self.encoders[series.name].fit_transform( + series.to_frame() + ) # NEW: Store encoder for inference mode self.label_encoders[series.name] = self.encoders[series.name] else: @@ -445,60 +574,85 @@ def encode_column(series): # also save label mappings 
label_mappings[series.name] = { - int(code): label for code, label in enumerate(self.encoders[series.name].categories_[0]) - } + int(code): label + for code, label in enumerate(self.encoders[series.name].categories_[0]) + } return encoded_series.ravel() # Select only the categorical columns - df_categorical = df.select_dtypes(include=['object', 'category']).apply(encode_column) + df_categorical = df.select_dtypes(include=["object", "category"]).apply( + encode_column + ) # Combine the encoded categorical data with the numerical data - df_encoded = pd.concat([df.select_dtypes(exclude=['object', 'category']), df_categorical], axis=1) + df_encoded = pd.concat( + [df.select_dtypes(exclude=["object", "category"]), df_categorical], + axis=1, + ) # Store the variable types - variable_types = {col: 'categorical' for col in df_categorical.columns} - variable_types.update({col: 'numerical' for col in df.select_dtypes(exclude=['object', 'category']).columns}) + variable_types = {col: "categorical" for col in df_categorical.columns} + variable_types.update( + { + col: "numerical" + for col in df.select_dtypes(exclude=["object", "category"]).columns + } + ) return df_encoded, variable_types, label_mappings - def validate_input_data(self, train_dat, test_dat): - print("\n[INFO] ----------------- Checking for problems with the input data ----------------- ") + print( + "\n[INFO] ----------------- Checking for problems with the input data ----------------- " + ) errors = [] warnings = [] + def check_rownames(dat, split): # Check 1: Validate first columns are unique for file_name, df in dat.items(): if not df.index.is_unique: - identifier_type = "Sample labels" if file_name == 'clin' else "Feature names" - errors.append(f"Error in {split}/{file_name}.csv: {identifier_type} in the first column must be unique.") + identifier_type = ( + "Sample labels" if file_name == "clin" else "Feature names" + ) + errors.append( + f"Error in {split}/{file_name}.csv: {identifier_type} in the first 
column must be unique." + ) def check_sample_labels(dat, split): - clin_samples = set(dat['clin'].index) + clin_samples = set(dat["clin"].index) for file_name, df in dat.items(): - if file_name != 'clin': + if file_name != "clin": omics_samples = set(df.columns) matching_samples = clin_samples.intersection(omics_samples) if not matching_samples: - errors.append(f"Error: No matching sample labels found between {split}/clin.csv and {split}/{file_name}.csv.") + errors.append( + f"Error: No matching sample labels found between " + f"{split}/clin.csv and {split}/{file_name}.csv." + ) elif len(matching_samples) < len(clin_samples): missing_samples = clin_samples - matching_samples - warnings.append(f"Warning: Some sample labels in {split}/clin.csv are missing in {split}/{file_name}.csv: {missing_samples}") + warnings.append( + f"Warning: Some sample labels in {split}/clin.csv are " + f"missing in {split}/{file_name}.csv: {missing_samples}" + ) def check_common_features(train_dat, test_dat): for file_name in train_dat: - if file_name != 'clin' and file_name in test_dat: + if file_name != "clin" and file_name in test_dat: train_features = set(train_dat[file_name].index) test_features = set(test_dat[file_name].index) common_features = train_features.intersection(test_features) if not common_features: - errors.append(f"Error: No common features found between train/{file_name}.csv and test/{file_name}.csv.") + errors.append( + f"Error: No common features found between train/{file_name}.csv and test/{file_name}.csv." + ) - check_rownames(train_dat, 'train') - check_rownames(test_dat, 'test') + check_rownames(train_dat, "train") + check_rownames(test_dat, "test") - check_sample_labels(train_dat, 'train') - check_sample_labels(test_dat, 'test') + check_sample_labels(train_dat, "train") + check_sample_labels(test_dat, "test") check_common_features(train_dat, test_dat) @@ -514,21 +668,15 @@ def check_common_features(train_dat, test_dat): print(f"[ERROR] {i}. 
{error}") raise Exception("[ERROR] Please correct the above errors and try again.") - if not warnings and not errors: print("[INFO] Data structure is valid with no errors or warnings.") - """ DataImporterInference class for flexynesis inference mode. Add this to flexynesis/data.py """ -import pandas as pd -import numpy as np - - class DataImporterInference: """ @@ -538,21 +686,28 @@ class DataImporterInference: def __init__(self, test_data_path, artifacts_path, verbose=True): from .inference import _load_artifacts + self.test_data_path = test_data_path self.verbose = verbose self.artifacts = _load_artifacts(artifacts_path) # Map artifact keys to expected names (compatibility layer) - self.feature_names = self.artifacts.get("feature_lists", self.artifacts.get("feature_names", {})) - self.scalers = self.artifacts.get("transforms", self.artifacts.get("scalers", {})) + self.feature_names = self.artifacts.get( + "feature_lists", self.artifacts.get("feature_names", {}) + ) + self.scalers = self.artifacts.get( + "transforms", self.artifacts.get("scalers", {}) + ) self.label_encoders = self.artifacts.get("label_encoders", {}) - self.modalities = self.artifacts.get("data_types", self.artifacts.get("modalities", [])) + self.modalities = self.artifacts.get( + "data_types", self.artifacts.get("modalities", []) + ) self.target_variables = self.artifacts.get("target_variables", []) # For early fusion, we need feature lists for original modalities # but artifacts only have 'all'. We need to reconstruct from transforms. 
- if self.modalities == ['all']: - original_modalities = self.artifacts.get('original_modalities', []) + if self.modalities == ["all"]: + original_modalities = self.artifacts.get("original_modalities", []) # Feature lists for original modalities can be inferred from transforms # since transforms are keyed by original modality names if original_modalities: @@ -563,7 +718,6 @@ def __init__(self, test_data_path, artifacts_path, verbose=True): if self.verbose: print(f"[INFO] Loaded artifacts for modalities: {self.modalities}") - def import_data(self): """Returns MultiOmicDataset object""" test_data = {} @@ -571,7 +725,7 @@ def import_data(self): labels_df = None # Load clinical data - clin_path = os.path.join(self.test_data_path, 'clin.csv') + clin_path = os.path.join(self.test_data_path, "clin.csv") if os.path.exists(clin_path): labels_df = pd.read_csv(clin_path, index_col=0) @@ -579,20 +733,22 @@ def import_data(self): # For early fusion: load original modalities before concatenation # For covariates: load only the omics data (covariates come from clin.csv) modalities_to_load = self.modalities - if self.modalities == ['all']: - modalities_to_load = self.artifacts.get('original_modalities', []) + if self.modalities == ["all"]: + modalities_to_load = self.artifacts.get("original_modalities", []) if not modalities_to_load: - raise ValueError('[ERROR] Early fusion mode but original_modalities not found in artifacts') + raise ValueError( + "[ERROR] Early fusion mode but original_modalities not found in artifacts" + ) else: # Filter out 'covariates' from modalities_to_load # Covariates will be created from clinical data later - modalities_to_load = [m for m in self.modalities if m != 'covariates'] + modalities_to_load = [m for m in self.modalities if m != "covariates"] # Load each modality (skip 'covariates' - it's in clin.csv) for modality in modalities_to_load: - if modality == 'covariates': + if modality == "covariates": continue # Covariates are in clin.csv, not a 
separate file - file_path = os.path.join(self.test_data_path, f'{modality}.csv') + file_path = os.path.join(self.test_data_path, f"{modality}.csv") if not os.path.exists(file_path): raise FileNotFoundError(f"[ERROR] Required file not found: {file_path}") @@ -610,18 +766,24 @@ def import_data(self): extra = set(df.columns) - set(expected_features) if missing: - raise ValueError(f"[ERROR] {modality}: Missing {len(missing)} features required by model. " - f"Test data must have same preprocessing as training data.") + raise ValueError( + f"[ERROR] {modality}: Missing {len(missing)} features required by model. " + f"Test data must have same preprocessing as training data." + ) if extra and self.verbose: - print(f"[INFO] {modality}: Ignoring {len(extra)} extra features not in training") + print( + f"[INFO] {modality}: Ignoring {len(extra)} extra features not in training" + ) # Select only expected features in correct order df = df[expected_features] # Apply scaling scaler = self.scalers[modality] - df_scaled = pd.DataFrame(scaler.transform(df.values), index=df.index, columns=df.columns) + df_scaled = pd.DataFrame( + scaler.transform(df.values), index=df.index, columns=df.columns + ) # Store as DataFrame for cross-modality intersection test_data[modality] = df_scaled @@ -639,16 +801,22 @@ def import_data(self): ).float() # Create covariates matrix if needed - if 'covariates' in self.modalities and labels_df is not None: - from flexynesis.utils import create_covariate_matrix, get_variable_types - covariate_vars = self.artifacts.get('covariate_vars', []) + if "covariates" in self.modalities and labels_df is not None: + from flexynesis.utils import (create_covariate_matrix, + get_variable_types) + + covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: if self.verbose: print(f"[INFO] Creating covariate matrix for: {covariate_vars}") variable_types_clin = get_variable_types(labels_df) - covariates_df = create_covariate_matrix(covariate_vars, 
variable_types_clin, labels_df) + covariates_df = create_covariate_matrix( + covariate_vars, variable_types_clin, labels_df + ) # Convert to torch tensor and add to test_data - test_data['covariates'] = torch.from_numpy(covariates_df.T.values).float() + test_data["covariates"] = torch.from_numpy( + covariates_df.T.values + ).float() if samples is None: samples = covariates_df.T.index.tolist() @@ -668,8 +836,7 @@ def import_data(self): # Create mapping from old sample order to new order old_samples = samples # Original order from data CSV df_reordered = pd.DataFrame( - test_data[modality].numpy(), - index=old_samples + test_data[modality].numpy(), index=old_samples ).loc[common_samples] test_data[modality] = torch.from_numpy(df_reordered.values).float() @@ -681,35 +848,50 @@ def import_data(self): valid_mask = ~labels_df[col].isna() encoded = np.full(len(labels_df), -1, dtype=np.int64) - if hasattr(encoder, 'classes_'): # LabelEncoder + if hasattr(encoder, "classes_"): # LabelEncoder if valid_mask.sum() > 0: - encoded[valid_mask] = encoder.transform(labels_df[col][valid_mask].values) + encoded[valid_mask] = encoder.transform( + labels_df[col][valid_mask].values + ) ann_dict[col] = torch.from_numpy(encoded) - variable_types[col] = 'categorical' - label_mappings[col] = {int(c): l for c, l in enumerate(encoder.classes_)} + variable_types[col] = "categorical" + label_mappings[col] = { + int(c): l for c, l in enumerate(encoder.classes_) + } else: # OrdinalEncoder if valid_mask.sum() > 0: - encoded[valid_mask] = encoder.transform(labels_df[col][valid_mask].values.reshape(-1, 1)).ravel() + encoded[valid_mask] = encoder.transform( + labels_df[col][valid_mask].values.reshape(-1, 1) + ).ravel() ann_dict[col] = torch.from_numpy(encoded) - variable_types[col] = 'categorical' - label_mappings[col] = {int(c): l for c, l in enumerate(encoder.categories_[0])} + variable_types[col] = "categorical" + label_mappings[col] = { + int(c): l for c, l in enumerate(encoder.categories_[0]) + 
} - label_mappings[col][-1] = 'Unknown' # For missing values + label_mappings[col][-1] = "Unknown" # For missing values else: ann_dict[col] = torch.from_numpy(labels_df[col].values).float() - variable_types[col] = 'numerical' + variable_types[col] = "numerical" # Create features dict # For early fusion, get features from scalers since feature_lists only has 'all' - if self.modalities == ['all']: - modalities_for_features = self.artifacts.get('original_modalities', []) + if self.modalities == ["all"]: + modalities_for_features = self.artifacts.get("original_modalities", []) # Get features from scalers for each modality - features = {modality: list(self.scalers[modality].feature_names_in_) for modality in modalities_for_features} + features = { + modality: list(self.scalers[modality].feature_names_in_) + for modality in modalities_for_features + } else: - features = {modality: self.feature_names[modality] for modality in self.modalities} + features = { + modality: self.feature_names[modality] for modality in self.modalities + } # CRITICAL: Reorder test_data dict to match self.modalities order (model expects specific order) - test_data_ordered = {mod: test_data[mod] for mod in self.modalities if mod in test_data} + test_data_ordered = { + mod: test_data[mod] for mod in self.modalities if mod in test_data + } # Create MultiOmicDataset object dataset = MultiOmicDataset( @@ -718,15 +900,18 @@ def import_data(self): variable_types=variable_types, features=features, samples=samples, - label_mappings=label_mappings + label_mappings=label_mappings, ) # Concatenate for early fusion if needed - if self.modalities == ['all']: + if self.modalities == ["all"]: from itertools import chain + # For early fusion, we already loaded original modalities into test_data # Now concatenate them in the SAME ORDER as training - modality_order = self.artifacts.get('original_modalities', list(test_data.keys())) + modality_order = self.artifacts.get( + "original_modalities", 
list(test_data.keys()) + ) # Concatenate the data tensors concatenated_data = torch.cat([test_data[x] for x in modality_order], dim=1) @@ -735,12 +920,14 @@ def import_data(self): all_features = list(chain(*[dataset.features[x] for x in modality_order])) # Filter to expected features from artifacts - expected_all_features = self.feature_names['all'] - feature_indices = [i for i, f in enumerate(all_features) if f in expected_all_features] + expected_all_features = self.feature_names["all"] + feature_indices = [ + i for i, f in enumerate(all_features) if f in expected_all_features + ] # Update dataset with concatenated data - dataset.dat = {'all': concatenated_data[:, feature_indices]} - dataset.features = {'all': [all_features[i] for i in feature_indices]} + dataset.dat = {"all": concatenated_data[:, feature_indices]} + dataset.features = {"all": [all_features[i] for i in feature_indices]} return dataset @@ -749,16 +936,28 @@ class MultiOmicDataset(Dataset): """A PyTorch dataset for multiomic data. Args: - dat (dict): A dictionary with keys corresponding to different types of data and values corresponding to matrices of the same shape. All matrices must have the same number of samples (rows). + dat (dict): A dictionary with keys corresponding to different types of data and values + corresponding to matrices of the same shape. All matrices must have the same number + of samples (rows). ann (data.frame): Data frame with samples on the rows, sample annotations on the columns - features (list or np.array): A 1D array of feature names with length equal to the number of columns in each matrix. + features (list or np.array): A 1D array of feature names with length equal to the + number of columns in each matrix. samples (list or np.array): A 1D array of sample names with length equal to the number of rows in each matrix. Returns: A PyTorch dataset that can be used for training or evaluation. 
""" - def __init__(self, dat, ann, variable_types, features, samples, label_mappings, feature_ann=None): + def __init__( + self, + dat, + ann, + variable_types, + features, + samples, + label_mappings, + feature_ann=None, + ): """Initialize the dataset.""" self.dat = dat self.ann = ann @@ -776,14 +975,16 @@ def __getitem__(self, index): Returns: A tuple of two elements: - 1. A dictionary with keys corresponding to the different types of data in the input dictionary `dat`, and values corresponding to the data for the given sample. + 1. A dictionary with keys corresponding to the different + types of data in the input dictionary `dat`, and values + corresponding to the data for the given sample. 2. The label for the given sample. """ subset_dat = {x: self.dat[x][index] for x in self.dat.keys()} subset_ann = {x: self.ann[x][index] for x in self.ann.keys()} return subset_dat, subset_ann, self.samples[index] - def __len__ (self): + def __len__(self): """Get the total number of samples in the dataset. Returns: @@ -805,32 +1006,54 @@ def subset(self, indices): subset_samples = [self.samples[idx] for idx in indices] # Create a new dataset object - return MultiOmicDataset(subset_dat, subset_ann, self.variable_types, self.features, - subset_samples, self.label_mappings, self.feature_ann) + return MultiOmicDataset( + subset_dat, + subset_ann, + self.variable_types, + self.features, + subset_samples, + self.label_mappings, + self.feature_ann, + ) def get_feature_subset(self, feature_df): - """Get a subset of data matrices corresponding to specified features and concatenate them into a pandas DataFrame. + """Get a subset of data matrices corresponding to specified features and + concatenate them into a pandas DataFrame. Args: - feature_df (pandas.DataFrame): A DataFrame which contains at least two columns: 'layer' and 'name'. + feature_df (pandas.DataFrame): A DataFrame which contains at least + two columns: 'layer' and 'name'. 
Returns: A pandas DataFrame that concatenates the data matrices for the specified features from all layers. """ # Convert the DataFrame to a dictionary - feature_dict = feature_df.groupby('layer')['name'].apply(list).to_dict() + feature_dict = feature_df.groupby("layer")["name"].apply(list).to_dict() dfs = [] for layer, features in feature_dict.items(): if layer in self.dat: # Create a dictionary to look up indices by feature name for each layer - feature_index_dict = {feature: i for i, feature in enumerate(self.features[layer])} + feature_index_dict = { + feature: i for i, feature in enumerate(self.features[layer]) + } # Get the indices for the requested features - indices = [feature_index_dict[feature] for feature in features if feature in feature_index_dict] + indices = [ + feature_index_dict[feature] + for feature in features + if feature in feature_index_dict + ] # Subset the data matrix for the current layer using the indices subset = self.dat[layer][:, indices] # Convert the subset to a pandas DataFrame, add the layer name as a prefix to each column name - df = pd.DataFrame(subset, columns=[f'{layer}_{feature}' for feature in features if feature in feature_index_dict]) + df = pd.DataFrame( + subset, + columns=[ + f"{layer}_{feature}" + for feature in features + if feature in feature_index_dict + ], + ) dfs.append(df) else: print(f"Layer {layer} not found in the dataset.") @@ -844,9 +1067,12 @@ def get_feature_subset(self, feature_df): return result def get_dataset_stats(self): - stats = {': '.join(['feature_count in', x]): self.dat[x].shape[1] for x in self.dat.keys()} - stats['sample_count'] = len(self.samples) - return(stats) + stats = { + ": ".join(["feature_count in", x]): self.dat[x].shape[1] + for x in self.dat.keys() + } + stats["sample_count"] = len(self.samples) + return stats # given a MultiOmicDataset object, convert to Triplets (anchor,positive,negative) @@ -863,14 +1089,19 @@ def __init__(self, mydataset, main_var): self.labels_set, 
self.label_to_indices = self.get_label_indices(labels) # Valid anchor indices are those without NA labels - self.valid_indices = [i for i, label in enumerate(labels) if not np.isnan(label)] + self.valid_indices = [ + i for i, label in enumerate(labels) if not np.isnan(label) + ] def __getitem__(self, index): # We only use valid non-NA indices for anchors real_index = self.valid_indices[index] # get anchor sample and its label - anchor, y_dict = self.dataset[real_index][0], self.dataset[real_index][1] + anchor, y_dict = ( + self.dataset[real_index][0], + self.dataset[real_index][1], + ) # choose another sample with same label label = y_dict[self.main_var].item() @@ -881,11 +1112,12 @@ def __getitem__(self, index): # choose another sample with a different label # possible negative labels include NA import random + negative_label = random.choice(list(self.labels_set - set([label]))) negative_index = np.random.choice(self.label_to_indices[negative_label]) - pos = self.dataset[positive_index][0] # positive example - neg = self.dataset[negative_index][0] # negative example + pos = self.dataset[positive_index][0] # positive example + neg = self.dataset[negative_index][0] # negative example return anchor, pos, neg, y_dict def __len__(self): @@ -893,11 +1125,12 @@ def __len__(self): def get_label_indices(self, labels_array): # Filter out NaNs for a clean set of valid classes - valid_labels = [l for l in labels_array if not np.isnan(l)] + valid_labels = [label for label in labels_array if not np.isnan(label)] labels_set = set(valid_labels) - label_to_indices = {label: np.where(labels_array == label)[0] - for label in labels_set} + label_to_indices = { + label: np.where(labels_array == label)[0] for label in labels_set + } # Handle NA as a single separate group (if any exist) na_indices = np.where(np.isnan(labels_array))[0] @@ -912,11 +1145,15 @@ class MultiOmicDatasetNW(Dataset): def __init__(self, multiomic_dataset, interaction_df, modality_order=None): 
self.multiomic_dataset = multiomic_dataset self.interaction_df = interaction_df - self.modality_order = modality_order if modality_order else sorted(multiomic_dataset.dat.keys()) + self.modality_order = ( + modality_order if modality_order else sorted(multiomic_dataset.dat.keys()) + ) # Compute union of features in the data matrices that also appear in the network self.common_features = self.find_union_features() - self.gene_to_index = {gene: idx for idx, gene in enumerate(self.common_features)} + self.gene_to_index = { + gene: idx for idx, gene in enumerate(self.common_features) + } self.edge_index = self.create_edge_index() self.samples = self.multiomic_dataset.samples self.variable_types = self.multiomic_dataset.variable_types @@ -927,30 +1164,47 @@ def __init__(self, multiomic_dataset, interaction_df, modality_order=None): self.node_features_tensor = self.precompute_node_features() # Store labels for all samples - self.labels = {target_name: labels for target_name, labels in self.multiomic_dataset.ann.items()} + self.labels = { + target_name: labels + for target_name, labels in self.multiomic_dataset.ann.items() + } def find_union_features(self): # Find the union of all features in the multiomic dataset - all_omic_features = set().union(*(set(features) for features in self.multiomic_dataset.features.values())) + all_omic_features = set().union( + *(set(features) for features in self.multiomic_dataset.features.values()) + ) # Find the union of proteins involved in interactions - interaction_genes = set(self.interaction_df['protein1']).union(set(self.interaction_df['protein2'])) + interaction_genes = set(self.interaction_df["protein1"]).union( + set(self.interaction_df["protein2"]) + ) # Return the intersection of omic features and interaction genes return sorted(list(all_omic_features.intersection(interaction_genes))) def create_edge_index(self): # Create edges only if both proteins are within the available features filtered_df = self.interaction_df[ - 
(self.interaction_df['protein1'].isin(self.common_features)) & - (self.interaction_df['protein2'].isin(self.common_features)) + (self.interaction_df["protein1"].isin(self.common_features)) + & (self.interaction_df["protein2"].isin(self.common_features)) + ] + edge_list = [ + ( + self.gene_to_index[row["protein1"]], + self.gene_to_index[row["protein2"]], + ) + for index, row in filtered_df.iterrows() ] - edge_list = [(self.gene_to_index[row['protein1']], self.gene_to_index[row['protein2']]) for index, row in filtered_df.iterrows()] return torch.tensor(edge_list, dtype=torch.long).t() def precompute_node_features(self): num_samples = len(self.samples) num_nodes = len(self.common_features) num_data_types = len(self.multiomic_dataset.dat) - all_features = torch.full((num_samples, num_nodes, num_data_types), float('nan'), dtype=torch.float) + all_features = torch.full( + (num_samples, num_nodes, num_data_types), + float("nan"), + dtype=torch.float, + ) # CRITICAL: Use sorted keys to ensure consistent order across training/inference data_types_ordered = sorted(self.multiomic_dataset.dat.keys()) @@ -958,16 +1212,24 @@ def precompute_node_features(self): data_matrix = self.multiomic_dataset.dat[data_type] feature_indices = { gene: self.multiomic_dataset.features[data_type].get_loc(gene) - for gene in self.common_features if gene in self.multiomic_dataset.features[data_type] + for gene in self.common_features + if gene in self.multiomic_dataset.features[data_type] } - valid_indices = torch.tensor(list(feature_indices.values()), dtype=torch.long) - feature_positions = torch.tensor([self.gene_to_index[gene] for gene in feature_indices.keys()], dtype=torch.long) + valid_indices = torch.tensor( + list(feature_indices.values()), dtype=torch.long + ) + feature_positions = torch.tensor( + [self.gene_to_index[gene] for gene in feature_indices.keys()], + dtype=torch.long, + ) # Fill in the available data all_features[:, feature_positions, i] = data_matrix[:, valid_indices] # 
Precompute medians for all data types, ignoring NaN values - medians = torch.nanmedian(all_features, dim=1, keepdim=True).values # Use .values to get the actual median tensor + medians = torch.nanmedian( + all_features, dim=1, keepdim=True + ).values # Use .values to get the actual median tensor # Replace all NaN values in all_features with their corresponding median values isnan = torch.isnan(all_features) @@ -982,10 +1244,11 @@ def subset(self, indices): # Create a new instance of MultiOmicDatasetNW with the subsetted multiomic dataset return MultiOmicDatasetNW(dataset_subset, self.interaction_df.copy()) - def __getitem__(self, idx): node_features_tensor = self.node_features_tensor[idx] - y_dict = {target_name: self.labels[target_name][idx] for target_name in self.labels} + y_dict = { + target_name: self.labels[target_name][idx] for target_name in self.labels + } return node_features_tensor, y_dict, self.samples[idx] def __len__(self): @@ -1002,13 +1265,19 @@ def print_stats(self): # Calculate degree for each node degrees = torch.zeros(num_nodes, dtype=torch.long) degrees.index_add_(0, self.edge_index[0], torch.ones_like(self.edge_index[0])) - degrees.index_add_(0, self.edge_index[1], torch.ones_like(self.edge_index[1])) # For undirected graphs + degrees.index_add_( + 0, self.edge_index[1], torch.ones_like(self.edge_index[1]) + ) # For undirected graphs num_singletons = torch.sum(degrees == 0).item() non_singletons = degrees[degrees > 0] - mean_edges_per_node = non_singletons.float().mean().item() if len(non_singletons) > 0 else 0 - median_edges_per_node = non_singletons.median().item() if len(non_singletons) > 0 else 0 + mean_edges_per_node = ( + non_singletons.float().mean().item() if len(non_singletons) > 0 else 0 + ) + median_edges_per_node = ( + non_singletons.median().item() if len(non_singletons) > 0 else 0 + ) max_edges = degrees.max().item() print("Dataset Statistics:") @@ -1016,10 +1285,15 @@ def print_stats(self): print(f"Total number of edges: 
{num_edges}") print(f"Number of node features per node: {num_node_features}") print(f"Number of singletons (nodes with no edges): {num_singletons}") - print(f"Mean number of edges per node (excluding singletons): {mean_edges_per_node:.2f}") - print(f"Median number of edges per node (excluding singletons): {median_edges_per_node}") + print( + f"Mean number of edges per node (excluding singletons): {mean_edges_per_node:.2f}" + ) + print( + f"Median number of edges per node (excluding singletons): {median_edges_per_node}" + ) print(f"Max number of edges per node: {max_edges}") + def get_flexynesis_cache_dir() -> Path: """Resolve a writable cache directory for Flexynesis.""" env_cache = os.getenv("FLEXYNESIS_CACHE") @@ -1046,26 +1320,41 @@ class STRING(PYGDataset): - Downloads once, processes once per (version, organism, node_name). - Safe for concurrent jobs via a file lock. """ + base_folder = "STRING" version = "12.0" files = ("links", "aliases") - url = ("https://stringdb-downloads.org/download/" - "protein.{data}.v{version}/" - "{organism}.protein.{data}.v{version}.txt.gz") - - def __init__(self, root: str | None = None, organism: int = 9606, node_name: str = "gene_name") -> None: + url = ( + "https://stringdb-downloads.org/download/" + "protein.{data}.v{version}/" + "{organism}.protein.{data}.v{version}.txt.gz" + ) + + def __init__( + self, + root: str | None = None, + organism: int = 9606, + node_name: str = "gene_name", + ) -> None: self.organism = organism self.node_name = node_name # ---------- resolve central cache ---------- - cache_root = get_flexynesis_cache_dir() / self.base_folder / f"v{self.version}" / str(self.organism) + cache_root = ( + get_flexynesis_cache_dir() + / self.base_folder + / f"v{self.version}" + / str(self.organism) + ) # layout self._cache_root = cache_root self._cache_raw_dir = cache_root / "raw" self._cache_processed_dir = cache_root / "processed" self._processed_basename = f"graph_{self.node_name}.csv" - self._cache_processed_path = 
self._cache_processed_dir / self._processed_basename + self._cache_processed_path = ( + self._cache_processed_dir / self._processed_basename + ) self._cache_raw_dir.mkdir(parents=True, exist_ok=True) self._cache_processed_dir.mkdir(parents=True, exist_ok=True) @@ -1081,7 +1370,9 @@ def __init__(self, root: str | None = None, organism: int = 9606, node_name: str super().__init__(str(self._cache_root)) # Load processed data - self.graph_df = pd.read_csv(self.processed_paths[0], sep=",", header=0, index_col=0) + self.graph_df = pd.read_csv( + self.processed_paths[0], sep=",", header=0, index_col=0 + ) # -------------------- PyG stubs -------------------- def len(self) -> int: @@ -1123,7 +1414,7 @@ def cache_dir(self) -> Path: return self._cache_root -def read_user_graph(fpath, sep=None, header='infer', **pd_read_csv_kw): +def read_user_graph(fpath, sep=None, header="infer", **pd_read_csv_kw): """Read user-provided gene-gene interaction network file. This function reads a custom network file and validates that it contains @@ -1159,18 +1450,19 @@ def read_user_graph(fpath, sep=None, header='infer', **pd_read_csv_kw): if sep is None: # Use CSV Sniffer for robust separator detection import csv - with open(fpath, 'r') as f: + + with open(fpath, "r") as f: # Read first few lines for better detection sample = f.read(4096) try: sniffer = csv.Sniffer() - dialect = sniffer.sniff(sample, delimiters='\t,| ') + dialect = sniffer.sniff(sample, delimiters="\t,| ") sep = dialect.delimiter print(f"[INFO] Auto-detected separator using CSV Sniffer: {repr(sep)}") except csv.Error: # Fallback to tab if Sniffer fails - sep = '\t' + sep = "\t" print(f"[INFO] CSV Sniffer failed, using default separator: {repr(sep)}") # Read the file @@ -1186,21 +1478,51 @@ def read_user_graph(fpath, sep=None, header='infer', **pd_read_csv_kw): # Get column names (handle cases with or without header) if header is None or (isinstance(header, int) and header < 0): # No header - assign default names - 
print("[INFO] No header detected. Assuming first 3 columns are: GeneA, GeneB, Score") - df.columns = [f'col_{i}' for i in range(len(df.columns))] - col_gene_a = 'col_0' - col_gene_b = 'col_1' - col_score = 'col_2' + print( + "[INFO] No header detected. Assuming first 3 columns are: GeneA, GeneB, Score" + ) + df.columns = [f"col_{i}" for i in range(len(df.columns))] + col_gene_a = "col_0" + col_gene_b = "col_1" + col_score = "col_2" else: # Header present - use hybrid scoring to intelligently identify columns from difflib import SequenceMatcher # Define candidate keywords for each column type candidates = { - 'gene_a': ['genea', 'gene_a', 'gene1', 'protein1', 'node1', 'source', 'from'], - 'gene_b': ['geneb', 'gene_b', 'gene2', 'protein2', 'node2', 'target', 'to'], - 'score': ['score', 'weight', 'combined_score', 'correlation', 'confidence', - 'value', 'strength', 'coef', 'coefficient', 'corr', 'pvalue', 'p_value'] + "gene_a": [ + "genea", + "gene_a", + "gene1", + "protein1", + "node1", + "source", + "from", + ], + "gene_b": [ + "geneb", + "gene_b", + "gene2", + "protein2", + "node2", + "target", + "to", + ], + "score": [ + "score", + "weight", + "combined_score", + "correlation", + "confidence", + "value", + "strength", + "coef", + "coefficient", + "corr", + "pvalue", + "p_value", + ], } def score_column_match(col, col_idx, category, total_cols): @@ -1230,15 +1552,15 @@ def score_column_match(col, col_idx, category, total_cols): # Signal 4: Position hints (10 points) # First column likely gene_a, second likely gene_b, third+ likely score - if category == 'gene_a' and col_idx == 0: + if category == "gene_a" and col_idx == 0: total_score += 10 - elif category == 'gene_b' and col_idx == 1: + elif category == "gene_b" and col_idx == 1: total_score += 10 - elif category == 'score' and col_idx >= 2: + elif category == "score" and col_idx >= 2: total_score += 10 # Signal 5: Data type hint for score (5 points) - if category == 'score': + if category == "score": if 
pd.api.types.is_numeric_dtype(df[col]): total_score += 5 @@ -1248,7 +1570,7 @@ def score_column_match(col, col_idx, category, total_cols): total_cols = len(df.columns) matches = {} - for category in ['gene_a', 'gene_b', 'score']: + for category in ["gene_a", "gene_b", "score"]: best_col = None best_score = 0 @@ -1261,29 +1583,37 @@ def score_column_match(col, col_idx, category, total_cols): matches[category] = (best_col, best_score) # Extract matched columns and scores - col_gene_a, score_a = matches['gene_a'] - col_gene_b, score_b = matches['gene_b'] - col_score, score_s = matches['score'] + col_gene_a, score_a = matches["gene_a"] + col_gene_b, score_b = matches["gene_b"] + col_score, score_s = matches["score"] # Check if confidence is too low min_threshold = 30 - if score_a < min_threshold or score_b < min_threshold or score_s < min_threshold: - print(f"[WARNING] Low confidence in column detection. Using first 3 columns as fallback.") + if ( + score_a < min_threshold + or score_b < min_threshold + or score_s < min_threshold + ): + print( + "[WARNING] Low confidence in column detection. Using first 3 columns as fallback." 
+ ) col_gene_a = df.columns[0] col_gene_b = df.columns[1] col_score = df.columns[2] else: - print(f"[INFO] Detected columns - GeneA: '{col_gene_a}', GeneB: '{col_gene_b}', Score: '{col_score}'") + print( + f"[INFO] Detected columns - GeneA: '{col_gene_a}', GeneB: '{col_gene_b}', Score: '{col_score}'" + ) # Extract only the required 3 columns and standardize names result_df = df[[col_gene_a, col_gene_b, col_score]].copy() - result_df.columns = ['protein1', 'protein2', 'combined_score'] + result_df.columns = ["protein1", "protein2", "combined_score"] # Validate data types # Convert score to numeric if not already - if not pd.api.types.is_numeric_dtype(result_df['combined_score']): + if not pd.api.types.is_numeric_dtype(result_df["combined_score"]): try: - result_df['combined_score'] = pd.to_numeric(result_df['combined_score']) + result_df["combined_score"] = pd.to_numeric(result_df["combined_score"]) except ValueError: raise ValueError( f"Score column must contain numeric values. " @@ -1294,14 +1624,19 @@ def score_column_match(col, col_idx, category, total_cols): original_len = len(result_df) result_df = result_df.dropna() if len(result_df) < original_len: - print(f"[WARNING] Removed {original_len - len(result_df)} rows with missing values") + print( + f"[WARNING] Removed {original_len - len(result_df)} rows with missing values" + ) print(f"[INFO] Successfully loaded user graph: {len(result_df)} interactions") - print(f"[INFO] Score range: [{result_df['combined_score'].min():.4f}, {result_df['combined_score'].max():.4f}]") + print( + f"[INFO] Score range: [{result_df['combined_score'].min():.4f}, {result_df['combined_score'].max():.4f}]" + ) return result_df -def read_stringdb_links(fname, top_neighbors = 5): + +def read_stringdb_links(fname, top_neighbors=5): """ Reads and processes a STRING database file to extract and rank protein-protein interactions. 
@@ -1327,19 +1662,26 @@ def read_stringdb_links(fname, top_neighbors = 5): """ df = pd.read_csv(fname, header=0, sep=" ") df = df[df.combined_score > 400] - df_expanded = pd.concat([ - df.rename(columns={'protein1': 'protein', 'protein2': 'partner'}), - df.rename(columns={'protein2': 'protein', 'protein1': 'partner'}) - ]) + df_expanded = pd.concat( + [ + df.rename(columns={"protein1": "protein", "protein2": "partner"}), + df.rename(columns={"protein2": "protein", "protein1": "partner"}), + ] + ) # Sort the expanded DataFrame by 'combined_score' in descending order - df_expanded_sorted = df_expanded.sort_values(by='combined_score', ascending=False) - # Reduce to unique interactions to avoid counting duplicates - df_expanded_unique = df_expanded_sorted.drop_duplicates(subset=['protein', 'partner']) - top_interactions = df_expanded_unique.groupby('protein').head(top_neighbors) - df = top_interactions.rename(columns={'protein': 'protein1', 'partner': 'protein2'}) - df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map(lambda a: a.split(".")[-1]) + df_expanded_sorted = df_expanded.sort_values(by="combined_score", ascending=False) + # Reduce to unique interactions to avoid counting duplicates + df_expanded_unique = df_expanded_sorted.drop_duplicates( + subset=["protein", "partner"] + ) + top_interactions = df_expanded_unique.groupby("protein").head(top_neighbors) + df = top_interactions.rename(columns={"protein": "protein1", "partner": "protein2"}) + df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map( + lambda a: a.split(".")[-1] + ) return df + def read_stringdb_aliases(fname: str, node_name: str) -> dict[str, str]: if node_name == "gene_id": source = ("Ensembl_HGNC_ensembl_gene_id", "Ensembl_gene") @@ -1385,6 +1727,7 @@ def fn(a): graph_df[["protein1", "protein2"]] = graph_df[["protein1", "protein2"]].map(fn) return graph_df + def stringdb_links_to_list(df): lst = df[["protein1", "protein2"]].to_numpy().tolist() return lst @@ -1401,7 +1744,7 @@ 
def split_by_median(tensor_dict): # Convert to categorical, but preserve NaNs tensor_cat = (tensor > median_val).float() - tensor_cat[torch.isnan(tensor)] = float('nan') + tensor_cat[torch.isnan(tensor)] = float("nan") new_dict[key] = tensor_cat else: # If tensor is not numerical, leave it as it is diff --git a/flexynesis/feature_selection.py b/flexynesis/feature_selection.py index 5b072d79..2b65cb1b 100644 --- a/flexynesis/feature_selection.py +++ b/flexynesis/feature_selection.py @@ -1,23 +1,22 @@ -# Tools to do feature selection +# Tools to do feature selection import numpy as np import pandas as pd +from scipy.sparse import csgraph, csr_matrix, diags from scipy.spatial.distance import pdist, squareform from sklearn.neighbors import kneighbors_graph -from scipy.sparse import csgraph - -from scipy.sparse import csr_matrix, diags from tqdm import tqdm + def laplacian_score(X, k=5, t=None): """ Calculate Laplacian Score for each feature in the dataset. - + Parameters: X: numpy array (n_samples, n_features) - Input data k: int - Number of nearest neighbors to consider t: float - Heat kernel parameter (optional) - + Returns: scores: (n_features,) - Laplacian Scores for each feature """ @@ -27,11 +26,11 @@ def laplacian_score(X, k=5, t=None): dist_matrix = squareform(pdist(X)) # Calculate the k-nearest neighbors adjacency matrix - W = kneighbors_graph(X, k, mode='connectivity', include_self=True) + W = kneighbors_graph(X, k, mode="connectivity", include_self=True) # Compute the heat kernel (optional) if t is not None: - W = csr_matrix(np.exp(-dist_matrix ** 2 / t)) + W = csr_matrix(np.exp(-(dist_matrix**2) / t)) # Normalize the adjacency matrix D = np.array(W.sum(axis=1)).flatten() @@ -46,8 +45,8 @@ def laplacian_score(X, k=5, t=None): # Calculate Laplacian Score for each feature scores = [] D = diags(D) # Convert to sparse diagonal matrix for efficient dot product - for i in tqdm(range(n_features), desc='Calculating Laplacian scores'): - fi = X[:,i] + for i in 
tqdm(range(n_features), desc="Calculating Laplacian scores"): + fi = X[:, i] fi = fi - np.dot(S, fi).sum() / n_samples L_score = np.dot(fi.T, L @ fi) / np.dot(fi.T, D @ fi) scores.append(L_score) @@ -57,14 +56,14 @@ def laplacian_score(X, k=5, t=None): def remove_redundant_features(X, laplacian_scores, threshold, topN=None): """ - Removes features from the dataset based on correlation and Laplacian scores, optionally + Removes features from the dataset based on correlation and Laplacian scores, optionally keeping only the top N features based on their scores, and manages redundant features. - This function evaluates features in a dataset for redundancy by measuring the correlation - between them. If the absolute correlation between two features exceeds a specified threshold, - the feature with the higher Laplacian score (a lower score is better) is considered redundant - and is marked for removal. - If the number of selected features is less than a specified top N, additional features are + This function evaluates features in a dataset for redundancy by measuring the correlation + between them. If the absolute correlation between two features exceeds a specified threshold, + the feature with the higher Laplacian score (a lower score is better) is considered redundant + and is marked for removal. + If the number of selected features is less than a specified top N, additional features are included from the redundant set based on their Laplacian scores until the top N count is reached. Parameters @@ -98,13 +97,15 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): indicate features more important for preserving data structure. - Features initially marked as redundant but included during the top-up process to meet the topN requirement are removed from the redundant_features_df before it is returned. 
- """ + """ correlation_matrix = np.corrcoef(X.T) - ranked_indices = np.argsort(laplacian_scores) # Assuming minimizing laplacian_scores + ranked_indices = np.argsort( + laplacian_scores + ) # Assuming minimizing laplacian_scores selected_features = [] redundant_features = {} - for idx in tqdm(ranked_indices, desc='Filtering redundant features'): + for idx in tqdm(ranked_indices, desc="Filtering redundant features"): correlated = False for selected_idx in selected_features: correlation = np.abs(correlation_matrix[idx, selected_idx]) @@ -115,15 +116,19 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): break if correlated: - redundant_features[idx] = {'correlated_with': correlated_feature, 'correlation_score': correlation_score} + redundant_features[idx] = { + "correlated_with": correlated_feature, + "correlation_score": correlation_score, + } else: selected_features.append(idx) # Topping up from redundant features if fewer than topN features are selected if topN is not None and len(selected_features) < topN: - shortfall = topN - len(selected_features) # Sort redundant features by their laplacian score, prioritizing lower scores - sorted_redundant_indices = sorted(redundant_features.keys(), key=lambda x: laplacian_scores[x]) + sorted_redundant_indices = sorted( + redundant_features.keys(), key=lambda x: laplacian_scores[x] + ) topped_up_features = [] for idx in sorted_redundant_indices: if len(selected_features) < topN: @@ -135,24 +140,29 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): # Remove topped-up features from the redundant_features dictionary for idx in topped_up_features: del redundant_features[idx] - + if len(redundant_features) > 0: # Convert redundant_features dictionary to DataFrame - redundant_features_df = pd.DataFrame([ - { - "feature": X.columns[idx], - "correlated_with": X.columns[redundant_features[idx]['correlated_with']], - "correlation_score": redundant_features[idx]['correlation_score'] - 
} - for idx in redundant_features - ]) + redundant_features_df = pd.DataFrame( + [ + { + "feature": X.columns[idx], + "correlated_with": X.columns[ + redundant_features[idx]["correlated_with"] + ], + "correlation_score": redundant_features[idx]["correlation_score"], + } + for idx in redundant_features + ] + ) return X.columns[selected_features], redundant_features_df else: return X.columns[selected_features], pd.DataFrame() + def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0.9): """ - Filters features in a dataset based on Laplacian scores and removes highly correlated features, + Filters features in a dataset based on Laplacian scores and removes highly correlated features, retaining only the top N features with the lowest scores and optionally considering correlation. This function computes Laplacian scores for each feature in the dataset to measure its importance @@ -170,13 +180,13 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 k : int, optional The number of nearest neighbors to consider for computing the Laplacian score (default is 5). t : float, optional - The heat kernel parameter for Laplacian score computation. If None, the default behavior + The heat kernel parameter for Laplacian score computation. If None, the default behavior applies without a heat kernel (default is None). topN : int, optional The number of top features to keep based on the lowest Laplacian scores (default is 100). correlation_threshold : float, optional - The Pearson correlation coefficient threshold for identifying and removing highly correlated - features. Features with a correlation above this threshold are considered redundant + The Pearson correlation coefficient threshold for identifying and removing highly correlated + features. Features with a correlation above this threshold are considered redundant (default is 0.9). 
Returns @@ -198,21 +208,32 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 correlated with any already selected feature. - The process may select additional features beyond `topN` before correlation filtering to ensure that the best candidates are considered. The final number of features, however, is pruned to `topN`. - """ - print("[INFO] Implementing feature selection using laplacian score for layer:",layer,"with ",X.shape[1],"features"," and ",X.shape[0], " samples ") - - feature_log = pd.DataFrame({'feature': X.columns, 'laplacian_score': np.nan}) + """ + print( + "[INFO] Implementing feature selection using laplacian score for layer:", + layer, + "with ", + X.shape[1], + "features", + " and ", + X.shape[0], + " samples ", + ) + + feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": np.nan}) # only apply filtering if topN < n_features - if topN >= X.shape[1]: - print("[INFO] No feature selection applied. Returning original matrix. Demanded # of features is ", - "larger than existing number of features") + if topN >= X.shape[1]: + print( + "[INFO] No feature selection applied. Returning original matrix. Demanded # of features is ", + "larger than existing number of features", + ) return X, feature_log - + # compute laplacian scores scores = laplacian_score(X.values, k, t) - - feature_log = pd.DataFrame({'feature': X.columns, 'laplacian_score': scores}) - + + feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": scores}) + # Sort the features based on their Laplacian Scores sorted_indices = np.argsort(scores) selected_feature_indices = sorted_indices[:topN] @@ -220,27 +241,34 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 if correlation_threshold < 1: # Choose the topN + 10% features with the lowest Laplacian Scores - # this is done to avoid unnecessary computation of correlation for all features. 
+ # this is done to avoid unnecessary computation of correlation for all features. topN_extended = int(topN + 0.10 * X.shape[1]) - topN_extended = min(topN_extended, X.shape[1]) # Ensure we don't exceed the number of features + topN_extended = min( + topN_extended, X.shape[1] + ) # Ensure we don't exceed the number of features selected_features = sorted_indices[:topN_extended] # Remove redundancy from topN + 10% features - selected_features, redundant_features_df = remove_redundant_features(X[X.columns[selected_feature_indices]], - scores[selected_feature_indices], correlation_threshold, - topN) + selected_features, redundant_features_df = remove_redundant_features( + X[X.columns[selected_feature_indices]], + scores[selected_feature_indices], + correlation_threshold, + topN, + ) # Prune down to topN features selected_features = selected_features[:topN] - - # if any redundant features found, merge feature log with info from this. + + # if any redundant features found, merge feature log with info from this. 
if not redundant_features_df.empty: # record the table of features which were removed due to redundancy - feature_log = pd.merge(feature_log, redundant_features_df, on = 'feature', how = 'outer') + feature_log = pd.merge( + feature_log, redundant_features_df, on="feature", how="outer" + ) # Extract the selected features from the dataset X_selected = X[selected_features] - - feature_log['selected'] = False - feature_log.loc[feature_log['feature'].isin(selected_features),'selected'] = True - + + feature_log["selected"] = False + feature_log.loc[feature_log["feature"].isin(selected_features), "selected"] = True + return X_selected, feature_log diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index ecf2150d..71ccb556 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -15,54 +15,56 @@ Input format: CSV/TSV file with genes as rows and samples as columns First column should be gene names/IDs - + Output format: CSV/TSV file with columns: GeneA, GeneB, Score Format matches input file extension (.csv or .tsv) """ -import pandas as pd -import numpy as np -from tqdm import tqdm import argparse import sys + +import pandas as pd import torch +from tqdm import tqdm -def build_network(expr_df, method='spearman', min_correlation=0.3, top_k=10, device=None): +def build_network( + expr_df, method="spearman", min_correlation=0.3, top_k=10, device=None +): """ Build co-expression network without storing full correlation matrix. Computes correlations in batches and immediately filters to save memory. - + Args: expr_df: DataFrame with genes as rows, samples as columns method: 'spearman' or 'pearson' min_correlation: Minimum absolute correlation threshold top_k: Keep top K neighbors per gene device: torch device (cuda/mps/cpu). Auto-detected if None. 
- + Returns: List of edge dictionaries with GeneA, GeneB, Score """ # Auto-detect device if not specified if device is None: if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") elif torch.backends.mps.is_available(): - device = torch.device('mps') + device = torch.device("mps") else: - device = torch.device('cpu') - + device = torch.device("cpu") + print(f"Using device: {device}") print(f"Calculating {method} correlations for {len(expr_df)} genes...") - + # Convert to torch tensor and move to GPU data = torch.tensor(expr_df.values, dtype=torch.float32, device=device) n_genes = data.shape[0] gene_names = expr_df.index.tolist() - - if method == 'spearman': + + if method == "spearman": # Convert to ranks for Spearman - process in batches with progress bar print("Computing ranks...") batch_size = 5000 @@ -71,83 +73,91 @@ def build_network(expr_df, method='spearman', min_correlation=0.3, top_k=10, dev for i in range(0, n_genes, batch_size): end_i = min(i + batch_size, n_genes) batch = data[i:end_i] - ranks[i:end_i] = torch.argsort(torch.argsort(batch, dim=1), dim=1).float() + ranks[i:end_i] = torch.argsort( + torch.argsort(batch, dim=1), dim=1 + ).float() pbar.update(end_i - i) data = ranks - elif method != 'pearson': + elif method != "pearson": raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'") - + # Standardize for correlation computation print("Standardizing data...") data_mean = data.mean(dim=1, keepdim=True) data_std = data.std(dim=1, keepdim=True, unbiased=False) data_normalized = (data - data_mean) / (data_std + 1e-8) - + # Compute correlations in batches and extract top-k on the fly batch_size = 1000 # Process 1000 genes at a time edges = [] - - print(f"Computing correlations and building network (min |r| = {min_correlation}, top {top_k} per gene)...") + + print( + f"Computing correlations and building network (min |r| = {min_correlation}, top {top_k} per gene)..." 
+ ) with tqdm(total=n_genes, desc="Processing genes") as pbar: for i in range(0, n_genes, batch_size): end_i = min(i + batch_size, n_genes) batch = data_normalized[i:end_i] - + # Compute correlation for this batch with all genes corr_batch = torch.mm(batch, data_normalized.T) / data.shape[1] - + # Process each gene in the batch for local_idx, global_idx in enumerate(range(i, end_i)): gene_corr = corr_batch[local_idx] gene_name = gene_names[global_idx] - + # Remove self-correlation gene_corr[global_idx] = 0 - + # Get absolute correlations abs_corr = gene_corr.abs() - + # Filter by threshold mask = abs_corr >= min_correlation - + # Get top-k if mask.sum() > top_k: # Get indices of top-k values - top_k_values, top_k_indices = torch.topk(abs_corr, min(top_k, len(abs_corr))) + top_k_values, top_k_indices = torch.topk( + abs_corr, min(top_k, len(abs_corr)) + ) # Filter to only those above threshold valid_mask = top_k_values >= min_correlation top_k_indices = top_k_indices[valid_mask] else: # All values above threshold top_k_indices = torch.where(mask)[0] - + # Add edges for neighbor_idx in top_k_indices: neighbor_idx = neighbor_idx.item() score = abs_corr[neighbor_idx].item() - edges.append({ - 'GeneA': gene_name, - 'GeneB': gene_names[neighbor_idx], - 'Score': score - }) - + edges.append( + { + "GeneA": gene_name, + "GeneB": gene_names[neighbor_idx], + "Score": score, + } + ) + pbar.update(end_i - i) - + return edges def generate_coexpression_network( input_file, output_file, - method='spearman', + method="spearman", min_correlation=0.3, top_k=10, remove_self_loops=True, - remove_duplicates=True + remove_duplicates=True, ): """ Main function to generate co-expression network. 
- + Args: input_file: Path to gene expression CSV/TSV file output_file: Path to output network file (CSV or TSV) @@ -160,97 +170,100 @@ def generate_coexpression_network( print("=" * 70) print("Co-expression Network Generator") print("=" * 70) - + # Load expression data print(f"\n[1/3] Loading expression data from: {input_file}") try: - sep = '\t' if input_file.endswith('.tsv') else ',' + sep = "\t" if input_file.endswith(".tsv") else "," expr_df = pd.read_csv(input_file, sep=sep, index_col=0) except Exception as e: + print(f"[ERROR] Failed to load file: {e}") # Fallback: try the other separator try: - sep = ',' if sep == '\t' else '\t' + sep = "," if sep == "\t" else "\t" expr_df = pd.read_csv(input_file, sep=sep, index_col=0) except Exception as e2: print(f"[ERROR] Failed to load file: {e2}") sys.exit(1) - + print(f" Expression matrix: {expr_df.shape[0]} genes × {expr_df.shape[1]} samples") - + # Check for missing values na_count = expr_df.isna().sum().sum() if na_count > 0: genes_with_na = expr_df.isna().any(axis=1).sum() print(f" [WARNING] Found {na_count} missing values in {genes_with_na} genes") - print(f" [INFO] Removing genes with missing data.") + print(" [INFO] Removing genes with missing data.") expr_df = expr_df.dropna() - print(f" [INFO] Retained {expr_df.shape[0]} genes ({genes_with_na} genes removed)") - + print( + f" [INFO] Retained {expr_df.shape[0]} genes ({genes_with_na} genes removed)" + ) + # Build network directly - print(f"\n[2/3] Building network...") + print("\n[2/3] Building network...") edges = build_network( - expr_df, - method=method, - min_correlation=min_correlation, - top_k=top_k + expr_df, method=method, min_correlation=min_correlation, top_k=top_k ) - + # Create dataframe network_df = pd.DataFrame(edges) - + if len(network_df) == 0: print("[WARNING] No edges found! Try lowering min_correlation threshold.") print("[ERROR] No edges in network! 
Exiting.") sys.exit(1) - + # Remove duplicate edges (A-B and B-A are the same) if remove_duplicates: print("\nRemoving duplicate edges...") - network_df['pair'] = network_df.apply( - lambda row: tuple(sorted([row['GeneA'], row['GeneB']])), - axis=1 + network_df["pair"] = network_df.apply( + lambda row: tuple(sorted([row["GeneA"], row["GeneB"]])), axis=1 ) original_len = len(network_df) - network_df = network_df.drop_duplicates(subset='pair').drop(columns='pair') + network_df = network_df.drop_duplicates(subset="pair").drop(columns="pair") print(f" Removed {original_len - len(network_df)} duplicate edges") - + # Save network (auto-detect format from extension) print(f"\n[3/3] Saving network to: {output_file}") - sep = '\t' if output_file.endswith('.tsv') else ',' + sep = "\t" if output_file.endswith(".tsv") else "," network_df.to_csv(output_file, sep=sep, index=False) - + # Print statistics print("\n" + "=" * 70) print("Network Generation Complete!") print("=" * 70) - print(f"\nNetwork Statistics:") + print("\nNetwork Statistics:") print(f" Total edges: {len(network_df):,}") print(f" Unique genes (GeneA): {network_df['GeneA'].nunique():,}") print(f" Unique genes (GeneB): {network_df['GeneB'].nunique():,}") - print(f" All unique genes: {len(set(network_df['GeneA']) | set(network_df['GeneB'])):,}") - print(f" Score range: [{network_df['Score'].min():.4f}, {network_df['Score'].max():.4f}]") + print( + f" All unique genes: {len(set(network_df['GeneA']) | set(network_df['GeneB'])):,}" + ) + print( + f" Score range: [{network_df['Score'].min():.4f}, {network_df['Score'].max():.4f}]" + ) print(f" Mean score: {network_df['Score'].mean():.4f}") print(f" Median score: {network_df['Score'].median():.4f}") - + # Show sample - print(f"\nSample edges (first 5):") + print("\nSample edges (first 5):") print(network_df.head().to_string(index=False)) - + print(f"\n{'=' * 70}") print("Usage with Flexynesis:") print(f"{'=' * 70}") - print(f"\nflexynesis --data_path \\") - print(f" 
--model_class GNN \\") - print(f" --gnn_conv_type GCN \\") - print(f" --target_variables \\") - print(f" --data_types gex,cnv \\") + print("\nflexynesis --data_path \\") + print(" --model_class GNN \\") + print(" --gnn_conv_type GCN \\") + print(" --target_variables \\") + print(" --data_types gex,cnv \\") print(f" --user_graph {output_file}") print() def main(): parser = argparse.ArgumentParser( - description='Generate gene co-expression network from expression data', + description="Generate gene co-expression network from expression data", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: @@ -258,7 +271,7 @@ def main(): python generate_coexpression_network.py \\ --input expression_data.csv \\ --output coexpression_network.csv - + # Use Pearson correlation with stricter threshold python generate_coexpression_network.py \\ --input expression_data.csv \\ @@ -266,7 +279,7 @@ def main(): --method pearson \\ --min_correlation 0.5 \\ --top_k 15 - + # More permissive network (more edges) python generate_coexpression_network.py \\ --input expression_data.csv \\ @@ -277,62 +290,67 @@ def main(): Input file format: CSV/TSV file with genes as rows, samples as columns First column should contain gene names/IDs - + Example: ,Sample1,Sample2,Sample3 TP53,5.2,6.1,5.8 BRCA1,7.3,6.9,7.1 EGFR,8.1,8.3,8.0 -""" +""", ) - + parser.add_argument( - '--input', '-i', + "--input", + "-i", required=True, - help='Input gene expression file (CSV/TSV supported)' + help="Input gene expression file (CSV/TSV supported)", ) - + parser.add_argument( - '--output', '-o', + "--output", + "-o", required=True, - help='Output network file (CSV/TSV supported)' + help="Output network file (CSV/TSV supported)", ) - + parser.add_argument( - '--method', '-m', - default='spearman', - choices=['spearman', 'pearson'], - help='Correlation method (default: spearman)' + "--method", + "-m", + default="spearman", + choices=["spearman", "pearson"], + help="Correlation method (default: 
spearman)", ) - + parser.add_argument( - '--min_correlation', '-c', + "--min_correlation", + "-c", type=float, default=0.3, - help='Minimum absolute correlation to include (default: 0.3)' + help="Minimum absolute correlation to include (default: 0.3)", ) - + parser.add_argument( - '--top_k', '-k', + "--top_k", + "-k", type=int, default=10, - help='Keep top K neighbors per gene (default: 10)' + help="Keep top K neighbors per gene (default: 10)", ) - + parser.add_argument( - '--keep_self_loops', - action='store_true', - help='Keep self-correlations (gene-gene) - not recommended' + "--keep_self_loops", + action="store_true", + help="Keep self-correlations (gene-gene) - not recommended", ) - + parser.add_argument( - '--keep_duplicates', - action='store_true', - help='Keep duplicate edges (A-B and B-A) - not recommended' + "--keep_duplicates", + action="store_true", + help="Keep duplicate edges (A-B and B-A) - not recommended", ) - + args = parser.parse_args() - + # Generate network generate_coexpression_network( input_file=args.input, @@ -341,7 +359,7 @@ def main(): min_correlation=args.min_correlation, top_k=args.top_k, remove_self_loops=not args.keep_self_loops, - remove_duplicates=not args.keep_duplicates + remove_duplicates=not args.keep_duplicates, ) diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 5b1dde5f..13bc3740 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -8,23 +8,26 @@ from types import SimpleNamespace import numpy as np -import joblib import torch from safetensors.torch import load_file - MODEL_REGISTRY = { - "DirectPred": ("flexynesis.models.direct_pred", "DirectPred"), - "supervised_vae": ("flexynesis.models.supervised_vae", "supervised_vae"), - "CrossModalPred": ("flexynesis.models.crossmodal_pred","CrossModalPred"), - "MultiTripletNetwork": ("flexynesis.models.triplet_encoder","MultiTripletNetwork"), - "GNN": ("flexynesis.models.gnn_early", "GNN"), + "DirectPred": ("flexynesis.models.direct_pred", 
"DirectPred"), + "supervised_vae": ("flexynesis.models.supervised_vae", "supervised_vae"), + "CrossModalPred": ("flexynesis.models.crossmodal_pred", "CrossModalPred"), + "MultiTripletNetwork": ( + "flexynesis.models.triplet_encoder", + "MultiTripletNetwork", + ), + "GNN": ("flexynesis.models.gnn_early", "GNN"), } + def check_model_type(file_path): - import struct import json - with open(file_path, 'rb') as f: + import struct + + with open(file_path, "rb") as f: header_start = f.read(8) if len(header_start) < 8: @@ -32,12 +35,12 @@ def check_model_type(file_path): # 1. Try SafeTensors check first try: - header_size = struct.unpack(' 0): + + if ( + platform.system() == "Darwin" + and os.environ.get("GITHUB_ACTIONS") == "true" + and num_workers > 0 + ): import warnings + warnings.warn( f"Detected macOS in GHA: Setting num_workers=0 (was {num_workers}) to avoid permission error in GHA.", - UserWarning + UserWarning, ) num_workers = 0 self.num_workers = num_workers - - self.DataLoader = torch.utils.data.DataLoader # use torch data loader by default - - if self.model_class.__name__ == 'MultiTripletNetwork': - self.loader_dataset = TripletMultiOmicDataset(self.dataset, self.target_variables[0]) + + self.DataLoader = ( + torch.utils.data.DataLoader + ) # use torch data loader by default + + if self.model_class.__name__ == "MultiTripletNetwork": + self.loader_dataset = TripletMultiOmicDataset( + self.dataset, self.target_variables[0] + ) # If config_path is provided, use it if config_path: @@ -135,55 +167,65 @@ def __init__(self, dataset, model_class, config_name, target_variables, if self.config_name in external_config: self.space = external_config[self.config_name] else: - raise ValueError(f"'{self.config_name}' not found in the provided config file.") + raise ValueError( + f"'{self.config_name}' not found in the provided config file." 
+ ) else: if self.config_name in search_spaces: self.space = search_spaces[self.config_name] # get batch sizes (a function of dataset size) self.space.append(self.get_batch_space()) else: - raise ValueError(f"'{self.config_name}' not found in the default config.") + raise ValueError( + f"'{self.config_name}' not found in the default config." + ) - def get_batch_space(self, min_size = 32, max_size = 128): + def get_batch_space(self, min_size=32, max_size=128): m = int(np.log2(len(self.dataset) * 0.8)) st = int(np.log2(min_size)) end = int(np.log2(max_size)) if m < end: end = m - s = Categorical([np.power(2, x) for x in range(st, end+1)], name = 'batch_size') + s = Categorical([np.power(2, x) for x in range(st, end + 1)], name="batch_size") return s - - def setup_trainer(self, params, current_step, total_steps, full_train = False): + + def setup_trainer(self, params, current_step, total_steps, full_train=False): # Configure callbacks and trainer for the current fold mycallbacks = [] if self.plot_losses: - mycallbacks.append(LiveLossPlot(hyperparams=params, current_step=current_step, total_steps=total_steps)) + mycallbacks.append( + LiveLossPlot( + hyperparams=params, + current_step=current_step, + total_steps=total_steps, + ) + ) else: mycallbacks.append(self.progress_bar) - # when training on a full dataset; no cross-validation or no validation splits; + # when training on a full dataset; no cross-validation or no validation splits; # we don't do early stopping early_stop_callback = None - if self.early_stop_patience > 0 and full_train == False: + if self.early_stop_patience > 0 and not full_train: early_stop_callback = self.init_early_stopping() mycallbacks.append(early_stop_callback) - + trainer = pl.Trainer( - #deterministic = True, - #precision = '16-mixed', # mixed precision training - max_epochs=int(params['epochs']), - gradient_clip_val=1.0, - gradient_clip_algorithm='norm', + # deterministic = True, + # precision = '16-mixed', # mixed precision training + 
max_epochs=int(params["epochs"]), + gradient_clip_val=1.0, + gradient_clip_algorithm="norm", log_every_n_steps=5, callbacks=mycallbacks, default_root_dir="./", logger=False, enable_checkpointing=False, devices=1, - accelerator=self.device_type + accelerator=self.device_type, ) return trainer, early_stop_callback - - def objective(self, params, current_step, total_steps, full_train = False): + + def objective(self, params, current_step, total_steps, full_train=False): # Unpack or construct specific model arguments model_args = { "config": params, @@ -195,19 +237,26 @@ def objective(self, params, current_step, total_steps, full_train = False): "use_loss_weighting": self.use_loss_weighting, "device_type": self.device_type, } - - if self.model_class.__name__ == 'GNN': - model_args['gnn_conv_type'] = self.gnn_conv_type - if self.model_class.__name__ == 'CrossModalPred': - model_args['input_layers'] = self.input_layers - model_args['output_layers'] = self.output_layers + + if self.model_class.__name__ == "GNN": + model_args["gnn_conv_type"] = self.gnn_conv_type + if self.model_class.__name__ == "CrossModalPred": + model_args["input_layers"] = self.input_layers + model_args["output_layers"] = self.output_layers if full_train: # Train on the full dataset - full_loader = self.DataLoader(self.loader_dataset, batch_size=int(params['batch_size']), - shuffle=True, pin_memory=True, drop_last=True) + full_loader = self.DataLoader( + self.loader_dataset, + batch_size=int(params["batch_size"]), + shuffle=True, + pin_memory=True, + drop_last=True, + ) model = self.model_class(**model_args) - trainer, _ = self.setup_trainer(params, current_step, total_steps, full_train = True) + trainer, _ = self.setup_trainer( + params, current_step, total_steps, full_train=True + ) trainer.fit(model, train_dataloaders=full_loader) return model # Return the trained model @@ -215,39 +264,64 @@ def objective(self, params, current_step, total_steps, full_train = False): validation_losses = [] epochs = 
[] - if self.use_cv: # if the user asks for cross-validation + if self.use_cv: # if the user asks for cross-validation kf = KFold(n_splits=self.n_splits, shuffle=True) split_iterator = kf.split(self.loader_dataset) - else: # otherwise do a single train/validation split + else: # otherwise do a single train/validation split # Compute the number of samples for training based on the ratio num_val = int(len(self.loader_dataset) * self.val_size) num_train = len(self.loader_dataset) - num_val - train_subset, val_subset = random_split(self.loader_dataset, [num_train, num_val]) + train_subset, val_subset = random_split( + self.loader_dataset, [num_train, num_val] + ) # create single split format similar to KFold - split_iterator = [(list(train_subset.indices), list(val_subset.indices))] + split_iterator = [ + (list(train_subset.indices), list(val_subset.indices)) + ] i = 1 - model = None # save the model if not using cross-validation + model = None # save the model if not using cross-validation for train_index, val_index in split_iterator: - print(f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}") + print( + f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}" + ) train_subset = torch.utils.data.Subset(self.loader_dataset, train_index) val_subset = torch.utils.data.Subset(self.loader_dataset, val_index) - train_loader = self.DataLoader(train_subset, batch_size=int(params['batch_size']), - pin_memory=True, shuffle=True, drop_last=True, num_workers = self.num_workers, prefetch_factor = None, - persistent_workers = self.num_workers > 0) - val_loader = self.DataLoader(val_subset, batch_size=int(params['batch_size']), - pin_memory=True, shuffle=False, num_workers = self.num_workers, prefetch_factor = None, - persistent_workers = self.num_workers > 0) + train_loader = self.DataLoader( + train_subset, + batch_size=int(params["batch_size"]), + pin_memory=True, + shuffle=True, + 
drop_last=True, + num_workers=self.num_workers, + prefetch_factor=None, + persistent_workers=self.num_workers > 0, + ) + val_loader = self.DataLoader( + val_subset, + batch_size=int(params["batch_size"]), + pin_memory=True, + shuffle=False, + num_workers=self.num_workers, + prefetch_factor=None, + persistent_workers=self.num_workers > 0, + ) model = self.model_class(**model_args) - trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps) + trainer, early_stop_callback = self.setup_trainer( + params, current_step, total_steps + ) print(f"[INFO] hpo config:{params}") - trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + trainer.fit( + model, + train_dataloaders=train_loader, + val_dataloaders=val_loader, + ) if early_stop_callback.stopped_epoch: epochs.append(early_stop_callback.stopped_epoch) else: - epochs.append(int(params['epochs'])) + epochs.append(int(params["epochs"])) validation_result = trainer.validate(model, dataloaders=val_loader) - val_loss = validation_result[0]['val_loss'] + val_loss = validation_result[0]["val_loss"] validation_losses.append(val_loss) i += 1 if not self.use_cv: @@ -256,27 +330,41 @@ def objective(self, params, current_step, total_steps, full_train = False): # Calculate average validation loss across all folds avg_val_loss = np.mean(validation_losses) avg_epochs = int(np.mean(epochs)) - return avg_val_loss, avg_epochs, model - - def perform_tuning(self, hpo_patience = 0): - opt = Optimizer(dimensions=self.space, n_initial_points=10, acq_func="gp_hedge", acq_optimizer="auto") + return avg_val_loss, avg_epochs, model + + def perform_tuning(self, hpo_patience=0): + opt = Optimizer( + dimensions=self.space, + n_initial_points=10, + acq_func="gp_hedge", + acq_optimizer="auto", + ) best_loss = np.inf best_params = None best_epochs = 0 best_model = None - # keep track of the streak of HPO iterations without improvement + # keep track of the streak of HPO iterations without improvement 
no_improvement_count = 0 - with tqdm(total=self.n_iter, desc='Tuning Progress') as pbar: + with tqdm(total=self.n_iter, desc="Tuning Progress") as pbar: for i in range(self.n_iter): np.int = int # Ensure int type is correctly handled suggested_params_list = opt.ask() - suggested_params_dict = {param.name: value for param, value in zip(self.space, suggested_params_list)} - loss, avg_epochs, model = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter) + suggested_params_dict = { + param.name: value + for param, value in zip(self.space, suggested_params_list) + } + loss, avg_epochs, model = self.objective( + suggested_params_dict, + current_step=i + 1, + total_steps=self.n_iter, + ) if self.use_cv: - print(f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}") + print( + f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}" + ) opt.tell(suggested_params_list, loss) if loss < best_loss: @@ -289,34 +377,53 @@ def perform_tuning(self, hpo_patience = 0): no_improvement_count += 1 # Increment the no improvement counter # Print result of each iteration - pbar.set_postfix({'Iteration': i+1, 'Best Loss': best_loss}) + pbar.set_postfix({"Iteration": i + 1, "Best Loss": best_loss}) pbar.update(1) # Early stopping condition if no_improvement_count >= hpo_patience & hpo_patience > 0: - print(f"No improvement in best loss for {hpo_patience} iterations, stopping hyperparameter optimisation early.") + print( + f"No improvement in best loss for {hpo_patience} iterations, " + "stopping hyperparameter optimisation early." 
+ ) break # Break out of the loop - best_params_dict = {param.name: value for param, value in zip(self.space, best_params)} if best_params else None - print(f"[INFO] current best val loss: {best_loss}; best params: {best_params_dict} since {no_improvement_count} hpo iterations") - + best_params_dict = ( + {param.name: value for param, value in zip(self.space, best_params)} + if best_params + else None + ) + print( + f"[INFO] current best val loss: {best_loss}; best params: " + f"{best_params_dict} since {no_improvement_count} hpo iterations" + ) # Convert best parameters from list to dictionary and include epochs - best_params_dict = {param.name: value for param, value in zip(self.space, best_params)} - best_params_dict['epochs'] = best_epochs + best_params_dict = { + param.name: value for param, value in zip(self.space, best_params) + } + best_params_dict["epochs"] = best_epochs if self.use_cv: # Build a final model based on best parameters if using cross-validation - print(f"[INFO] Building a final model using best params: {best_params_dict}") - best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True) + print( + f"[INFO] Building a final model using best params: {best_params_dict}" + ) + best_model = self.objective( + best_params_dict, + current_step=0, + total_steps=1, + full_train=True, + ) + + return best_model, best_params_dict - return best_model, best_params_dict def init_early_stopping(self): """Initialize the early stopping callback.""" return EarlyStopping( - monitor='val_loss', + monitor="val_loss", patience=self.early_stop_patience, verbose=False, - mode='min' + mode="min", ) def load_and_convert_config(self, config_path): @@ -325,8 +432,8 @@ def load_and_convert_config(self, config_path): raise ValueError(f"Config file '{config_path}' doesn't exist.") # Read the config file - if config_path.endswith('.yaml') or config_path.endswith('.yml'): - with open(config_path, 'r') as file: + if config_path.endswith(".yaml") 
or config_path.endswith(".yml"): + with open(config_path, "r") as file: loaded_config = yaml.safe_load(file) else: raise ValueError("Unsupported file format. Use .yaml or .yml") @@ -349,18 +456,13 @@ def load_and_convert_config(self, config_path): return search_space_user -from torch.utils.data import DataLoader, random_split -from sklearn.model_selection import KFold -import numpy as np -import random, copy, logging - class FineTuner(pl.LightningModule): """ - FineTuner class is designed for fine-tuning trained flexynesis models with flexible control over parameters such as + FineTuner class is designed for fine-tuning trained flexynesis models with flexible control over parameters such as learning rates and component freezing, utilizing cross-validation to optimize generalization. - This class allows the application of different configuration strategies to either freeze or unfreeze specific - model components, while also exploring different learning rates to find the optimal setting. + This class allows the application of different configuration strategies to either freeze or unfreeze specific + model components, while also exploring different learning rates to find the optimal setting. It carries out cross-validation to find the best combination of parameter freezing strategies and learning rates. Attributes: @@ -382,37 +484,59 @@ class FineTuner(pl.LightningModule): run_experiments(): Executes the finetuning process across all configurations and learning rates, evaluates using cross-validation, and selects the best configuration based on validation loss. 
""" - def __init__(self, model, dataset, n_splits=5, batch_size=32, learning_rates=None, max_epoch = 50, freeze_configs = None): + + def __init__( + self, + model, + dataset, + n_splits=5, + batch_size=32, + learning_rates=None, + max_epoch=50, + freeze_configs=None, + ): super().__init__() logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) - self.original_model = model + self.original_model = model self.dataset = dataset # Use the entire dataset self.n_splits = n_splits self.batch_size = batch_size self.kfold = KFold(n_splits=self.n_splits, shuffle=True) - self.learning_rates = learning_rates if learning_rates else [model.config['lr'], model.config['lr']/10, model.config['lr']/100] - self.folds_data = list(self.kfold.split(np.arange(len(self.dataset)))) + self.learning_rates = ( + learning_rates + if learning_rates + else [ + model.config["lr"], + model.config["lr"] / 10, + model.config["lr"] / 100, + ] + ) + self.folds_data = list(self.kfold.split(np.arange(len(self.dataset)))) self.max_epoch = max_epoch - self.freeze_configs = freeze_configs if freeze_configs else [ - {'encoders': True, 'supervisors': False}, - {'encoders': False, 'supervisors': True}, - {'encoders': False, 'supervisors': False} - ] - - if model.__class__.__name__ == 'MultiTripletNetwork': + self.freeze_configs = ( + freeze_configs + if freeze_configs + else [ + {"encoders": True, "supervisors": False}, + {"encoders": False, "supervisors": True}, + {"encoders": False, "supervisors": False}, + ] + ) + + if model.__class__.__name__ == "MultiTripletNetwork": # modify dataset structure to accommodate TripletNetworks self.dataset = TripletMultiOmicDataset(dataset, model.main_var) - + def apply_freeze_config(self, config): # Freeze or unfreeze encoders for encoder in self.model.encoders: for param in encoder.parameters(): - param.requires_grad = not config['encoders'] - + param.requires_grad = not config["encoders"] + # Freeze or unfreeze supervisors for mlp in self.model.MLPs.values(): 
for param in mlp.parameters(): - param.requires_grad = not config['supervisors'] + param.requires_grad = not config["supervisors"] def train_dataloader(self): # Override to load data for the current fold @@ -431,62 +555,103 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): # Call the model's validation step without logging - val_loss = self.model.validation_step(batch, batch_idx, log=False) + val_loss = self.model.validation_step(batch, batch_idx, log=False) self.log("val_loss", val_loss, on_epoch=True, prog_bar=True) return val_loss - + def configure_optimizers(self): - return torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.learning_rate) + return torch.optim.Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.learning_rate, + ) def run_experiments(self): val_loss_results = [] for lr in self.learning_rates: for config in self.freeze_configs: fold_losses = [] - epochs = [] # record how many epochs the training happened + epochs = [] # record how many epochs the training happened for fold in range(self.n_splits): - model_copy = copy.deepcopy(self.original_model) # Deep copy the model for each fold + model_copy = copy.deepcopy( + self.original_model + ) # Deep copy the model for each fold self.model = model_copy - self.apply_freeze_config(config) # try freezing different components + self.apply_freeze_config( + config + ) # try freezing different components self.current_fold = fold self.learning_rate = lr early_stopping = EarlyStopping( - monitor='val_loss', + monitor="val_loss", patience=3, verbose=False, - mode='min' + mode="min", + ) + trainer = pl.Trainer( + max_epochs=self.max_epoch, + devices=1, + accelerator="auto", + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + callbacks=[early_stopping], + ) + trainer.fit( + self, + train_dataloaders=self.train_dataloader(), + 
val_dataloaders=self.val_dataloader(), ) - trainer = pl.Trainer(max_epochs=self.max_epoch, devices=1, accelerator='auto', logger=False, enable_checkpointing=False, - enable_progress_bar = False, enable_model_summary=False, callbacks=[early_stopping]) - trainer.fit(self, train_dataloaders=self.train_dataloader(), val_dataloaders=self.val_dataloader()) stopped_epoch = early_stopping.stopped_epoch - val_loss = trainer.validate(self, dataloaders = self.val_dataloader(), verbose = False) - fold_losses.append(val_loss[0]['val_loss']) # Adjust based on your validation output format + val_loss = trainer.validate( + self, dataloaders=self.val_dataloader(), verbose=False + ) + fold_losses.append( + val_loss[0]["val_loss"] + ) # Adjust based on your validation output format epochs.append(stopped_epoch) - #print(f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, val_loss: {val_loss}, freeze {config}") + # print( + # f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, " + # f"val_loss: {val_loss}, freeze {config}" + # ) avg_val_loss = np.mean(fold_losses) avg_epochs = int(np.mean(epochs)) - print(f"[INFO] average 5-fold cross-validation loss {avg_val_loss} for learning rate: {lr} freeze {config}, average epochs {avg_epochs}") - val_loss_results.append({'learning_rate': lr, 'average_val_loss': avg_val_loss, 'freeze': config, 'epochs': avg_epochs}) + print( + f"[INFO] average 5-fold cross-validation loss {avg_val_loss} " + f"for learning rate: {lr} freeze {config}, " + f"average epochs {avg_epochs}" + ) + val_loss_results.append( + { + "learning_rate": lr, + "average_val_loss": avg_val_loss, + "freeze": config, + "epochs": avg_epochs, + } + ) # Find the best configuration based on validation loss - best_config = min(val_loss_results, key=lambda x: x['average_val_loss']) - print(f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}", - f"with average validation loss: {best_config['average_val_loss']} and average 
epochs: {best_config['epochs']}") + best_config = min(val_loss_results, key=lambda x: x["average_val_loss"]) + print( + f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}", + f"with average validation loss: {best_config['average_val_loss']} " + f"and average epochs: {best_config['epochs']}", + ) # build a final model using the best setup on all samples final_model = copy.deepcopy(self.model) self.model = final_model - self.learning_rate = best_config['learning_rate'] - self.apply_freeze_config(best_config['freeze']) + self.learning_rate = best_config["learning_rate"] + self.apply_freeze_config(best_config["freeze"]) dl = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True) - final_trainer = pl.Trainer(max_epochs=best_config['epochs'], devices=1, accelerator='auto', logger=False, enable_checkpointing=False) + final_trainer = pl.Trainer( + max_epochs=best_config["epochs"], + devices=1, + accelerator="auto", + logger=False, + enable_checkpointing=False, + ) final_trainer.fit(self, train_dataloaders=dl) - - -import matplotlib.pyplot as plt -from IPython.display import display -from lightning import Callback class LiveLossPlot(Callback): @@ -495,7 +660,7 @@ class LiveLossPlot(Callback): This class is a PyTorch Lightning callback that plots training loss and other metrics live as the model trains. It is especially useful for tracking the progress of hyperparameter optimization (HPO) steps. - + Compatible with papermill/nbclient execution by using display handles for proper output tracking. Attributes: @@ -510,6 +675,7 @@ class LiveLossPlot(Callback): on_train_epoch_end(trainer, pl_module): Updates and plots the loss after each training epoch. plot_losses(): Renders the loss plot with the current training metrics. 
""" + def __init__(self, hyperparams, current_step, total_steps, figsize=(8, 6)): super().__init__() self.hyperparams = hyperparams @@ -547,7 +713,9 @@ def plot_losses(self): ax.plot(epochs_range, losses_to_plot, label=key) - hyperparams_str = ', '.join(f"{key}={value}" for key, value in self.hyperparams.items()) + hyperparams_str = ", ".join( + f"{key}={value}" for key, value in self.hyperparams.items() + ) title = f"HPO Step={self.current_step} out of {self.total_steps}\n({hyperparams_str})" ax.set_title(title) @@ -555,12 +723,12 @@ def plot_losses(self): ax.set_ylabel("Loss") ax.legend() fig.tight_layout() - + # Use display with a handle for proper papermill/nbclient compatibility # This avoids the display_id assertion error if self.display_handle is None: self.display_handle = display(fig, display_id=True) else: self.display_handle.update(fig) - + plt.close(fig) diff --git a/flexynesis/models/__init__.py b/flexynesis/models/__init__.py index 2fad7361..eb677382 100644 --- a/flexynesis/models/__init__.py +++ b/flexynesis/models/__init__.py @@ -1,6 +1,13 @@ +from .crossmodal_pred import CrossModalPred from .direct_pred import DirectPred +from .gnn_early import GNN from .supervised_vae import supervised_vae from .triplet_encoder import MultiTripletNetwork -from .crossmodal_pred import CrossModalPred -from .gnn_early import GNN -__all__ = ["DirectPred", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNN"] + +__all__ = [ + "DirectPred", + "supervised_vae", + "MultiTripletNetwork", + "CrossModalPred", + "GNN", +] diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 1c004e50..0ddef098 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -1,38 +1,45 @@ +import itertools + +import lightning as pl +import numpy as np +import pandas as pd import torch -import itertools +from captum.attr import GradientShap, IntegratedGradients from torch import nn from torch.nn import 
functional as F -from torch.utils.data import Dataset, DataLoader, random_split - -import pandas as pd -import numpy as np - -import lightning as pl -from scipy import stats +from torch.utils.data import DataLoader -from captum.attr import IntegratedGradients, GradientShap - -from ..modules import * -from ..utils import to_device_safe +from ..modules import MLP, Decoder, Encoder, cox_ph_loss class CrossModalPred(pl.LightningModule): """ - A Cross-Modality Prediction Architecture that encodes user-specified input data modalities and + A Cross-Modality Prediction Architecture that encodes user-specified input data modalities and tries to reconstruct user-specificed output data modalities. In the case where input/output data modalities - are the same, this behaves like an auto-encoder. - The network also can be connected to one or more MLPs for outcome variable prediction. - + are the same, this behaves like an auto-encoder. + The network also can be connected to one or more MLPs for outcome variable prediction. + dataset: dictionary of data matrices - input_layers: which data modalities from `dataset` to encode (use a subset of keys from `dataset`) - output_layers: which data modalities are aimed to be reconsructed via decoders (use a subset of keys from `dataset`). - + input_layers: which data modalities from `dataset` to encode (use a + subset of keys from `dataset`) + output_layers: which data modalities are aimed to be reconsructed via + decoders (use a subset of keys from `dataset`). 
+ """ - def __init__(self, config, dataset, target_variables = None, batch_variables = None, - surv_event_var = None, surv_time_var = None, - input_layers = None, output_layers = None, - use_loss_weighting = True, - device_type = None): + + def __init__( + self, + config, + dataset, + target_variables=None, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + input_layers=None, + output_layers=None, + use_loss_weighting=True, + device_type=None, + ): super(CrossModalPred, self).__init__() self.config = config self.target_variables = target_variables @@ -43,56 +50,86 @@ def __init__(self, config, dataset, target_variables = None, batch_variables = if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + self.batch_variables if self.batch_variables else self.target_variables - self.variable_types = dataset.variable_types + self.variables = ( + self.target_variables + self.batch_variables + if self.batch_variables + else self.target_variables + ) + self.variable_types = dataset.variable_types self.ann = dataset.ann - - self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) - self.output_layers = output_layers if output_layers else list(dataset.dat.keys()) - + + self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) + self.output_layers = ( + output_layers if output_layers else list(dataset.dat.keys()) + ) + self.feature_importances = {} - + self.device_type = device_type self.use_loss_weighting = use_loss_weighting - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() - for loss_type in itertools.chain(self.variables, ['mmd_loss']): + for loss_type in itertools.chain(self.variables, ["mmd_loss"]): self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) - - # create 
a list of Encoder instances for separately encoding each input omics layer - input_dims = [len(dataset.features[self.input_layers[i]]) for i in range(len(self.input_layers))] - self.encoders = nn.ModuleList([Encoder(input_dims[i], - # define hidden_dim size as a factor of input_dim - [int(input_dims[i] * config['hidden_dim_factor'])], - config['latent_dim']) - for i in range(len(self.input_layers))]) - + + # create a list of Encoder instances for separately encoding each input omics layer + input_dims = [ + len(dataset.features[self.input_layers[i]]) + for i in range(len(self.input_layers)) + ] + self.encoders = nn.ModuleList( + [ + Encoder( + input_dims[i], + # define hidden_dim size as a factor of input_dim + [int(input_dims[i] * config["hidden_dim_factor"])], + config["latent_dim"], + ) + for i in range(len(self.input_layers)) + ] + ) + # Fully connected layers for concatenated means and log_vars - self.FC_mean = nn.Linear(len(self.input_layers) * config['latent_dim'], config['latent_dim']) - self.FC_log_var = nn.Linear(len(self.input_layers) * config['latent_dim'], config['latent_dim']) - - # list of decoders to decode the latent layer into the target/output layers - output_dims = [len(dataset.features[self.output_layers[i]]) for i in range(len(self.output_layers))] - self.decoders = nn.ModuleList([Decoder(config['latent_dim'], - [int(output_dims[i] * config['hidden_dim_factor'])], - output_dims[i]) - for i in range(len(self.output_layers))]) + self.FC_mean = nn.Linear( + len(self.input_layers) * config["latent_dim"], config["latent_dim"] + ) + self.FC_log_var = nn.Linear( + len(self.input_layers) * config["latent_dim"], config["latent_dim"] + ) + + # list of decoders to decode the latent layer into the target/output layers + output_dims = [ + len(dataset.features[self.output_layers[i]]) + for i in range(len(self.output_layers)) + ] + self.decoders = nn.ModuleList( + [ + Decoder( + config["latent_dim"], + [int(output_dims[i] * config["hidden_dim_factor"])], + 
output_dims[i], + ) + for i in range(len(self.output_layers)) + ] + ) # define supervisor heads # using ModuleDict to store multiple MLPs - self.MLPs = nn.ModuleDict() + self.MLPs = nn.ModuleDict() for var in self.variables: - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[var])) - self.MLPs[var] = MLP(input_dim = config['latent_dim'], - hidden_dim = config['supervisor_hidden_dim'], - output_dim = num_class) - + self.MLPs[var] = MLP( + input_dim=config["latent_dim"], + hidden_dim=config["supervisor_hidden_dim"], + output_dim=num_class, + ) + def multi_encoder(self, x_list): """ Encode each input matrix separately using the corresponding Encoder. @@ -117,7 +154,7 @@ def multi_encoder(self, x_list): mean = self.FC_mean(torch.cat(means, dim=1)) log_var = self.FC_log_var(torch.cat(log_vars, dim=1)) return mean, log_var - + def forward(self, x_list_input): """ Forward pass through the model. @@ -134,20 +171,20 @@ def forward(self, x_list_input): - y_pred (torch.Tensor): Predicted output. """ mean, log_var = self.multi_encoder(x_list_input) - + # generate latent layer z = self.reparameterization(mean, log_var) # decode the latent space to target output layer(s) x_hat_list = [self.decoders[i](z) for i in range(len(self.output_layers))] - #run the supervisor heads using the latent layer as input + # run the supervisor heads using the latent layer as input outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(z) - + return x_hat_list, z, mean, log_var, outputs - + def reparameterization(self, mean, var): """ Reparameterize the mean and variance values. @@ -159,10 +196,10 @@ def reparameterization(self, mean, var): Returns: torch.Tensor: Latent representation. 
""" - epsilon = torch.randn_like(var) - z = mean + var*epsilon + epsilon = torch.randn_like(var) + z = mean + var * epsilon return z - + def configure_optimizers(self): """ Configure the optimizer for the model. @@ -170,63 +207,69 @@ def configure_optimizers(self): Returns: torch.optim.Adam: Adam optimizer with learning rate 1e-3. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. 
""" - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. 
If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -235,118 +278,137 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Executes one training step using a single batch of data from a cross-modality prediction model. - + Args: train_batch (tuple): The batch data containing input features and target labels. batch_idx (int): The index of the current batch. log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. - + Returns: - torch.Tensor: The total loss for the current training batch, combining MMD loss for latent space regularization, - reconstruction losses for each output layer, and losses from supervisor heads for specified target variables. - - This method processes the batch by encoding input features from specified layers, decoding them to reconstruct the - output layers, and calculating the Maximum Mean Discrepancy (MMD) loss for latent space regularization. It computes - the reconstruction loss for each target/output layer. Additional losses are computed for other target variables in the - dataset, particularly handling survival analysis if applicable. All losses are aggregated to compute a total loss, - which is logged and returned. 
+ torch.Tensor: The total loss for the current training batch, + combining MMD loss for latent space regularization, + reconstruction losses for each output layer, and losses from + supervisor heads for specified target variables. + + This method processes the batch by encoding input features from + specified layers, decoding them to reconstruct the output layers, and + calculating the Maximum Mean Discrepancy (MMD) loss for latent space + regularization. It computes the reconstruction loss for each + target/output layer. Additional losses are computed for other target + variables in the dataset, particularly handling survival analysis if + applicable. All losses are aggregated to compute a total loss, which is + logged and returned. """ dat, y_dict, samples = train_batch - # get input omics modalities and encode them; decode them to output layers + # get input omics modalities and encode them; decode them to output layers x_list_input = [dat[x] for x in self.input_layers] x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) - + # compute mmd loss for the latent space + reconsruction loss for each target/output layer x_list_output = [dat[x] for x in self.output_layers] - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) + for i in range(len(self.output_layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} - + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} + for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) - else: + 
else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = self.compute_total_loss(losses) - # add total loss for logging - losses['train_loss'] = total_loss + # add total loss for logging + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ - Executes one validation step using a single batch of data, assessing the model's performance on the validation set. - + Executes one validation step using a single batch of data, assessing + the model's performance on the validation set. + Args: - val_batch (tuple): The batch data containing input features and target labels for validation. + val_batch (tuple): The batch data containing input features and + target labels for validation. batch_idx (int): The index of the current batch in the validation process. log (bool, optional): Indicates whether to log the validation losses during this step. Defaults to True. - + Returns: - torch.Tensor: The total loss for the current validation batch, calculated by combining MMD loss, reconstruction - losses, and losses from supervisor heads for specified target variables. - - In this method, the model processes input data by encoding it through specified input layers and decoding it to - targeted output layers. It computes the Maximum Mean Discrepancy (MMD) loss to measure the divergence between - the model's latent representations and a predefined distribution, along with reconstruction losses for output layers. - Additionally, it calculates losses for other target variables in the dataset, handling complex scenarios like survival - analysis where applicable. The aggregated losses are then summed up to form the total validation loss, which is logged - and returned. 
+ torch.Tensor: The total loss for the current validation batch, + calculated by combining MMD loss, reconstruction losses, and + losses from supervisor heads for specified target variables. + + In this method, the model processes input data by encoding it through + specified input layers and decoding it to targeted output layers. It + computes the Maximum Mean Discrepancy (MMD) loss to measure the + divergence between the model's latent representations and a predefined + distribution, along with reconstruction losses for output layers. + Additionally, it calculates losses for other target variables in the + dataset, handling complex scenarios like survival analysis if applicable. + The aggregated losses are then summed up to form the total validation + loss, which is logged and returned. """ dat, y_dict, samples = val_batch # get input omics modalities and encode them x_list_input = [dat[x] for x in self.input_layers] x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) - + # compute mmd loss for the latent space + reconsruction loss for each target/output layer x_list_output = [dat[x] for x in self.output_layers] - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) + for i in range(len(self.output_layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) - else: + else: y_hat = outputs[var] y = y_dict[var] loss = 
self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - + def transform(self, dataset): """ Transform the input dataset to latent representation. @@ -361,10 +423,10 @@ def transform(self, dataset): x_list_input = [dataset.dat[x] for x in self.input_layers] M = self.forward(x_list_input)[1].detach().numpy() z = pd.DataFrame(M) - z.columns = [''.join(['E', str(x)]) for x in z.columns] + z.columns = ["".join(["E", str(x)]) for x in z.columns] z.index = dataset.samples return z - + def predict(self, dataset): """ Evaluate the model on a dataset. @@ -376,44 +438,47 @@ def predict(self, dataset): predicted values. """ self.eval() - + x_list_input = [dataset.dat[x] for x in self.input_layers] X_hat, z, mean, log_var, outputs = self.forward(x_list_input) - - predictions = {var: [] for var in self.variables} # Initialize prediction storage + + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} - + return predictions - - + def decode(self, dataset): """ - Extract the decoded values of the target/output 
layers + Extract the decoded values of the target/output layers """ self.eval() x_list_input = [dataset.dat[x] for x in self.input_layers] - x_list_output = [dataset.dat[x] for x in self.output_layers] x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) X = {} for i in range(len(self.output_layers)): x = pd.DataFrame(x_hat_list[i].detach().numpy()).transpose() layer = self.output_layers[i] x.columns = dataset.samples - x.index = dataset.features[layer] + x.index = dataset.features[layer] X[layer] = x return X - def compute_kernel(self, x, y): """ Compute the Gaussian kernel matrix between two sets of vectors. @@ -428,12 +493,12 @@ def compute_kernel(self, x, y): x_size = x.size(0) y_size = y.size(0) dim = x.size(1) - x = x.unsqueeze(1) # (x_size, 1, dim) - y = y.unsqueeze(0) # (1, y_size, dim) + x = x.unsqueeze(1) # (x_size, 1, dim) + y = y.unsqueeze(0) # (1, y_size, dim) tiled_x = x.expand(x_size, y_size, dim) tiled_y = y.expand(x_size, y_size, dim) - kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim) - return torch.exp(-kernel_input) # (x_size, y_size) + kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim) + return torch.exp(-kernel_input) # (x_size, y_size) def compute_mmd(self, x, y): """ @@ -449,7 +514,7 @@ def compute_mmd(self, x, y): x_kernel = self.compute_kernel(x, x) y_kernel = self.compute_kernel(y, y) xy_kernel = self.compute_kernel(x, y) - mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean() + mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean() return mmd def MMD_loss(self, latent_dim, z, xhat, x): @@ -465,27 +530,37 @@ def MMD_loss(self, latent_dim, z, xhat, x): Returns: torch.Tensor: A scalar tensor representing the MMD loss. 
""" - true_samples = torch.randn(200, latent_dim, device = self.device) - mmd = self.compute_mmd(true_samples, z) # compute maximum mean discrepancy (MMD) - nll = (xhat - x).pow(2).mean() #negative log likelihood - return mmd+nll - - # Adaptor forward function for captum integrated gradients. + true_samples = torch.randn(200, latent_dim, device=self.device) + mmd = self.compute_mmd( + true_samples, z + ) # compute maximum mean discrepancy (MMD) + nll = (xhat - x).pow(2).mean() # negative log likelihood + return mmd + nll + + # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps for IntegratedGradients().attribute + steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] for i in range(steps): # get list of tensors for each step into a list of tensors x_step = [input_data[j][i] for j in range(len(input_data))] x_hat_list, z, mean, log_var, outputs = self.forward(x_step) outputs_list.append(outputs[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes the feature importance for each variable in the dataset + using either Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. 
@@ -500,27 +575,41 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. """ from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) - - print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) - + + print( + "[INFO] Computing feature importance for variable:", + target_var, + "on device:", + device, + ) + # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." 
+ ) # Get the number of classes for the target variable - if self.variable_types[target_var] == 'numerical': + if self.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[target_var])) @@ -528,49 +617,70 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) aggregated_attributions = [[] for _ in range(num_class)] - + for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in self.input_layers] input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( - torch.cat([torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0) + torch.cat( + [torch.zeros_like(x) for _ in range(steps_or_samples)], + dim=0, + ) for x in input_data ) if num_class == 1: - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): # 
returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) - # For each target class and for each data modality/layer, concatenate attributions accross batches + # For each target class and for each data modality/layer, concatenate attributions accross batches layers = list(self.input_layers) num_layers = len(layers) - processed_attributions = [] + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -583,32 +693,44 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - + # summarize feature importances # Compute absolute attributions # Move the processed tensors to CPU for further operations that are not supported on GPU - abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in processed_attributions] - # average over samples + abs_attr = [ + [torch.abs(a).cpu() for a in attr_class] + for attr_class 
in processed_attributions + ] + # average over samples imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') + self.to("cpu") - # combine into a single data frame + # combine into a single data frame df_list = [] for i in range(num_class): for j in range(len(layers)): features = dataset.features[layers[j]] importances = imp[i][j][0].detach().numpy() - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, 'importance': importances})) - df_imp = pd.concat(df_list, ignore_index = True) - + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) + df_imp = pd.concat(df_list, ignore_index=True) + # save scores in model self.feature_importances[target_var] = df_imp - - diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 4865994a..d2cd12b7 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -1,20 +1,16 @@ +import lightning as pl +import numpy as np +import pandas as pd import torch +from captum.attr import GradientShap, IntegratedGradients from torch import nn from torch.nn import functional as F -import lightning as pl -from torch.utils.data import Dataset, DataLoader, random_split - -import pandas as pd -import numpy as np -import os, argparse -from scipy import stats -from functools import reduce - -from captum.attr import IntegratedGradients, GradientShap +from torch.utils.data import DataLoader -from 
..modules import * +from ..modules import MLP, cox_ph_loss from ..utils import to_device_safe + class DirectPred(pl.LightningModule): """ A fully connected network for multi-omics integration with supervisor heads. @@ -30,9 +26,17 @@ class DirectPred(pl.LightningModule): device_type (str, optional): Type of device to run the model ('gpu' or 'cpu'). Defaults to None. """ - def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, use_loss_weighting = True, - device_type = None): + def __init__( + self, + config, + dataset, + target_variables, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + use_loss_weighting=True, + device_type=None, + ): super(DirectPred, self).__init__() self.config = config self.target_variables = target_variables @@ -43,11 +47,15 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables + self.variables = ( + self.target_variables + batch_variables + if batch_variables + else self.target_variables + ) self.feature_importances = {} self.use_loss_weighting = use_loss_weighting self.device_type = device_type - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() @@ -57,31 +65,43 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, self.variable_types = dataset.variable_types self.ann = dataset.ann self.layers = list(dataset.dat.keys()) - self.input_dims = [len(dataset.features[self.layers[i]]) for i in range(len(self.layers))] - - self.encoders = nn.ModuleList([ - MLP(input_dim=self.input_dims[i], - # define hidden_dim size relative to the input_dim size - 
hidden_dim=int(self.input_dims[i] * self.config['hidden_dim_factor']), - output_dim=self.config['latent_dim']) for i in range(len(self.layers))]) + self.input_dims = [ + len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) + ] + + self.encoders = nn.ModuleList( + [ + MLP( + input_dim=self.input_dims[i], + # define hidden_dim size relative to the input_dim size + hidden_dim=int( + self.input_dims[i] * self.config["hidden_dim_factor"] + ), + output_dim=self.config["latent_dim"], + ) + for i in range(len(self.layers)) + ] + ) if len(self.input_dims) > 1: self.fusion_block = nn.Linear( - in_features=self.config['latent_dim'] * len(self.layers), - out_features=self.config['latent_dim'] + in_features=self.config["latent_dim"] * len(self.layers), + out_features=self.config["latent_dim"], ) else: self.fusion_block = None - + self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs for var in self.variables: - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[var])) - self.MLPs[var] = MLP(input_dim=self.config['latent_dim'], - hidden_dim=self.config['supervisor_hidden_dim'], - output_dim=num_class) + self.MLPs[var] = MLP( + input_dim=self.config["latent_dim"], + hidden_dim=self.config["supervisor_hidden_dim"], + output_dim=num_class, + ) def forward(self, x_list): """ @@ -91,22 +111,26 @@ def forward(self, x_list): x_list (list of torch.Tensor): A list of input matrices (omics layers), one for each layer. Returns: - dict: A dictionary where each key-value pair corresponds to the target variable name and its predicted output respectively. + dict: A dictionary where each key-value pair corresponds to the + target variable name and its predicted output respectively. 
""" embeddings_list = [] # Process each input matrix with its corresponding Encoder for i, x in enumerate(x_list): embeddings_list.append(self.encoders[i](x)) embeddings_concat = torch.cat(embeddings_list, dim=1) - # if multiple embeddings, fuse them - embeddings = self.fusion_block(embeddings_concat) if self.fusion_block else embeddings_concat + # if multiple embeddings, fuse them + embeddings = ( + self.fusion_block(embeddings_concat) + if self.fusion_block + else embeddings_concat + ) outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(embeddings) - return outputs - - + return outputs + def configure_optimizers(self): """ Configure the optimizer for the DirectPred model. @@ -115,63 +139,69 @@ def configure_optimizers(self): torch.optim.Optimizer: The configured optimizer. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. 
""" - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. 
If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -180,15 +210,18 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Executes one training step using a single batch from the training dataset. @@ -200,8 +233,8 @@ def training_step(self, train_batch, batch_idx, log = True): Returns: torch.Tensor: The total loss computed for the batch. 
""" - - dat, y_dict, samples = train_batch + + dat, y_dict, samples = train_batch layers = dat.keys() x_list = [dat[x] for x in layers] outputs = self.forward(x_list) @@ -209,23 +242,23 @@ def training_step(self, train_batch, batch_idx, log = True): for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = self.compute_total_loss(losses) # add train loss for logging - losses['train_loss'] = total_loss + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ Executes one validation step using a single batch from the validation dataset. @@ -237,7 +270,7 @@ def validation_step(self, val_batch, batch_idx, log = True): Returns: torch.Tensor: The total loss computed for the batch. 
""" - dat, y_dict, samples = val_batch + dat, y_dict, samples = val_batch layers = dat.keys() x_list = [dat[x] for x in layers] outputs = self.forward(x_list) @@ -245,8 +278,8 @@ def validation_step(self, val_batch, batch_idx, log = True): for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] @@ -254,7 +287,7 @@ def validation_step(self, val_batch, batch_idx, log = True): loss = self.compute_loss(var, y, y_hat) losses[var] = loss total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss @@ -271,18 +304,29 @@ def predict(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device # Create a DataLoader with a practical batch size - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Process each batch for batch in dataloader: dat, y_dict, samples = batch - x_list = [to_device_safe(dat[x], device) for x in dat.keys()] # Prepare the data batch for processing + x_list = [ + 
to_device_safe(dat[x], device) for x in dat.keys() + ] # Prepare the data batch for processing # Perform the forward pass outputs = self.forward(x_list) @@ -290,17 +334,21 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} return predictions - + def transform(self, dataset): """ Transforms the input data into a lower-dimensional representation using trained encoders. 
@@ -313,10 +361,17 @@ def transform(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed embeddings_list = [] # Initialize a list to collect all batch embeddings sample_names = [] # List to collect sample names @@ -328,42 +383,62 @@ def transform(self, dataset): # Process each input matrix with its corresponding Encoder for i, x in enumerate(dat.values()): x = to_device_safe(x, device) # Move data to GPU - encoded_x = self.encoders[i](x) # Transform data using the corresponding encoder + encoded_x = self.encoders[i]( + x + ) # Transform data using the corresponding encoder batch_embeddings.append(encoded_x) - + # Concatenate all embeddings from the current batch embeddings_batch_concat = torch.cat(batch_embeddings, dim=1) - # if multiple embeddings, fuse them - embeddings_batch = self.fusion_block(embeddings_batch_concat) if self.fusion_block else embeddings_batch_concat + # if multiple embeddings, fuse them + embeddings_batch = ( + self.fusion_block(embeddings_batch_concat) + if self.fusion_block + else embeddings_batch_concat + ) - embeddings_list.append(embeddings_batch.detach().cpu()) # Move tensor back to CPU and detach + embeddings_list.append( + embeddings_batch.detach().cpu() + ) # Move tensor back to CPU and detach sample_names.extend(samples) # Collect sample names # Concatenate all batch embeddings into one tensor embeddings_concat = torch.cat(embeddings_list, dim=0) # Converting tensor to 
numpy array and then to DataFrame - embeddings_df = pd.DataFrame(embeddings_concat.numpy(), - index=sample_names, # Set DataFrame index to sample names - columns=[f"E{dim}" for dim in range(embeddings_concat.shape[1])]) + embeddings_df = pd.DataFrame( + embeddings_concat.numpy(), + index=sample_names, # Set DataFrame index to sample names + columns=[f"E{dim}" for dim in range(embeddings_concat.shape[1])], + ) return embeddings_df - - # Adaptor forward function for captum integrated gradients or gradient shap + + # Adaptor forward function for captum integrated gradients or gradient shap def forward_target(self, *args): input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps/samples for IntegratedGradients().attribute or GradientShap.attribute + steps = args[ + -1 + ] # number of steps/samples for IntegratedGradients().attribute or GradientShap.attribute outputs_list = [] for i in range(steps): # get list of tensors for each step into a list of tensors x_step = [input_data[j][i] for j in range(len(input_data))] out = self.forward(x_step) outputs_list.append(out[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes feature importance for each variable in the dataset using either + Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. 
@@ -378,13 +453,20 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. """ from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) @@ -395,10 +477,12 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." 
+ ) # Handle target class (numerical vs categorical) - if dataset.variable_types[target_var] == 'numerical': + if dataset.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique([y[target_var] for _, y, _ in dataset])) @@ -408,38 +492,59 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( - torch.cat([torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0) + torch.cat( + [torch.zeros_like(x) for _ in range(steps_or_samples)], + dim=0, + ) for x in input_data ) if num_class == 1: # returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): # returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - 
attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) # Post-process attributions layers = list(dataset.dat.keys()) @@ -454,9 +559,12 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in processed_attributions] + abs_attr = [ + [torch.abs(a).cpu() for a in attr_class] + for attr_class in processed_attributions + ] imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] - self.to('cpu') + self.to("cpu") # Combine results into a DataFrame df_list = [] @@ -464,13 +572,22 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad for j in range(len(layers)): features = dataset.features[layers[j]] importances = imp[i][j][0].detach().numpy() - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, - 'importance': 
importances})) + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) df_imp = pd.concat(df_list, ignore_index=True) self.feature_importances[target_var] = df_imp - diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index fb4be5f8..32b1fee1 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -1,19 +1,14 @@ +import lightning as pl import numpy as np import pandas as pd - import torch +from captum.attr import GradientShap, IntegratedGradients from torch import nn from torch.nn import functional as F -from torch.utils.data import random_split - -import lightning as pl - from torch.utils.data import DataLoader -from captum.attr import IntegratedGradients, GradientShap - -from ..utils import to_device_safe from ..modules import MLP, cox_ph_loss, flexGCN +from ..utils import to_device_safe class GNN(pl.LightningModule): @@ -51,20 +46,22 @@ class GNN(pl.LightningModule): surv_event_var (str, optional): The variable name representing survival events. surv_time_var (str, optional): The variable name representing survival times. use_loss_weighting (bool, optional): Whether to use uncertainty weighting in loss calculation. Defaults to True. - device_type (str, optional): Specifies the computation device ('gpu' or 'cpu'). Default is None, which uses 'cpu' if 'gpu' is not available. + device_type (str, optional): Specifies the computation device ('gpu' or 'cpu'). Default is None, which uses + 'cpu' if 'gpu' is not available. gnn_conv_type (str, optional): Specifies the type of graph convolutional layer to use. 
- """ + """ + def __init__( self, config, - dataset, # MultiomicDatasetNW object + dataset, # MultiomicDatasetNW object target_variables, batch_variables=None, surv_event_var=None, surv_time_var=None, use_loss_weighting=True, - device_type = None, - gnn_conv_type = None + device_type=None, + gnn_conv_type=None, ): super().__init__() self.config = config @@ -76,37 +73,57 @@ def __init__( if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + self.batch_variables if self.batch_variables else self.target_variables - self.variable_types = getattr(dataset, 'multiomic_dataset', dataset).variable_types - self.ann = getattr(dataset, 'multiomic_dataset', dataset).ann + self.variables = ( + self.target_variables + self.batch_variables + if self.batch_variables + else self.target_variables + ) + self.variable_types = getattr( + dataset, "multiomic_dataset", dataset + ).variable_types + self.ann = getattr(dataset, "multiomic_dataset", dataset).ann self.edge_index = dataset.edge_index - + self.feature_importances = {} self.use_loss_weighting = use_loss_weighting - self.device_type = device_type + self.device_type = device_type self.gnn_conv_type = gnn_conv_type - + from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - self.edge_index = to_device_safe(self.edge_index, device) # edge index is re-used across samples, so we keep it in device - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + self.edge_index = to_device_safe( + self.edge_index, device + ) # edge index is re-used across samples, so we keep it in device + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars 
= nn.ParameterDict() for var in self.variables: self.log_vars[var] = nn.Parameter(torch.zeros(1)) - - self.encoders = nn.ModuleList([ - flexGCN( - node_count = dataset[0][0].shape[0], #number of nodes - node_feature_count= dataset[0][0].shape[1], # number of node features - node_embedding_dim=int(self.config["node_embedding_dim"]), - num_convs = int(self.config['num_convs']), # Number of convolutional layers - output_dim=self.config["latent_dim"], - act = self.config['activation'], - conv = self.gnn_conv_type - )]) + + self.encoders = nn.ModuleList( + [ + flexGCN( + node_count=dataset[0][0].shape[0], # number of nodes + node_feature_count=dataset[0][0].shape[ + 1 + ], # number of node features + node_embedding_dim=int(self.config["node_embedding_dim"]), + num_convs=int( + self.config["num_convs"] + ), # Number of convolutional layers + output_dim=self.config["latent_dim"], + act=self.config["activation"], + conv=self.gnn_conv_type, + ) + ] + ) # Init output layers self.MLPs = nn.ModuleDict() @@ -118,10 +135,10 @@ def __init__( self.MLPs[var] = MLP( input_dim=self.config["latent_dim"], hidden_dim=self.config["supervisor_hidden_dim"], - output_dim=num_class + output_dim=num_class, ) - - def forward(self, x, edge_index): + + def forward(self, x, edge_index): """ Defines the forward pass of the GNN. @@ -139,13 +156,13 @@ def forward(self, x, edge_index): outputs[var] = mlp(embeddings) return outputs - - def training_step(self, batch, batch_idx, log = True): + def training_step(self, batch, batch_idx, log=True): """ Performs a training step including loss calculation and logging. Args: - batch (tuple): A batch of data consisting of features, target labels as a dictionary of tensors, and sample ids. + batch (tuple): A batch of data consisting of features, target labels + as a dictionary of tensors, and sample ids. Returns: float: Total loss for the batch. 
@@ -169,15 +186,22 @@ def training_step(self, batch, batch_idx, log = True): total_loss = self.compute_total_loss(losses) losses["train_loss"] = total_loss if log: - self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch)) + self.log_dict( + losses, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=len(batch), + ) return total_loss - def validation_step(self, batch, batch_idx, log = True): + def validation_step(self, batch, batch_idx, log=True): """ Performs a validation step, computing losses for a batch of data. Args: - batch (tuple): A batch of data consisting of features, target labels as a dictionary of tensors, and sample ids. + batch (tuple): A batch of data consisting of features, target labels as a + dictionary of tensors, and sample ids. Returns: float: Total validation loss for the batch. @@ -200,9 +224,15 @@ def validation_step(self, batch, batch_idx, log = True): total_loss = sum(losses.values()) losses["val_loss"] = total_loss if log: - self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch)) + self.log_dict( + losses, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=len(batch), + ) return total_loss - + def configure_optimizers(self): """ Configure the optimizer for the DirectPred model. @@ -210,23 +240,23 @@ def configure_optimizers(self): Returns: torch.optim.Optimizer: The configured optimizer. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. 
y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for @@ -235,17 +265,23 @@ def compute_loss(self, var, y, y_hat): if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) @@ -259,14 +295,14 @@ def compute_total_loss(self, losses): either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. 
- + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -278,13 +314,14 @@ def compute_total_loss(self, losses): # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance total_loss = sum( - torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items() + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - + def predict(self, dataset): """ Make predictions on an entire dataset. 
@@ -297,29 +334,42 @@ def predict(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device # Create a DataLoader with a practical batch size - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed edge_index = dataset.edge_index.to(device) # Move edge_index to GPU - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Process each batch - for x, y_dict,samples in dataloader: + for x, y_dict, samples in dataloader: x = x.to(device) # Move data to GPU outputs = self.forward(x, edge_index) for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} return predictions @@ -336,11 +386,18 @@ def transform(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import 
create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device edge_index = dataset.edge_index.to(device) # Move edge_index to GPU - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed all_embeddings = [] # List to store embeddings from all batches sample_ids = [] # List to store indices for all samples processed @@ -349,10 +406,12 @@ def transform(self, dataset): x = x.to(device) # Move data to GPU # notice we are using the first encoder (it is currently a early fusion method) - embeddings = self.encoders[0](x, edge_index).detach().cpu().numpy() # Compute embeddings and move to CPU + embeddings = ( + self.encoders[0](x, edge_index).detach().cpu().numpy() + ) # Compute embeddings and move to CPU all_embeddings.append(embeddings) - sample_ids.extend(samples) - + sample_ids.extend(samples) + # Concatenate all embeddings into a single numpy array all_embeddings = np.vstack(all_embeddings) @@ -363,24 +422,31 @@ def transform(self, dataset): columns=[f"E{dim}" for dim in range(all_embeddings.shape[1])], ) return embeddings_df - - # Adaptor forward function for captum integrated gradients. + + # Adaptor forward function for captum integrated gradients. 
def forward_target(self, *args): input_data = list(args[:-2]) # expect a single tensor (early integration) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps for IntegratedGradients().attribute + steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] for i in range(steps): - x_step = input_data[0][i] - #edges_step = edge_index[i] # although, identical, they get copied. + x_step = input_data[0][i] + # edges_step = edge_index[i] # although, identical, they get copied. out = self.forward(x_step, self.dataset_edge_index) outputs_list.append(out[target_var]) - return torch.cat(outputs_list, dim = 0) + return torch.cat(outputs_list, dim=0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes the feature importance for each variable in the dataset + using either Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. @@ -394,43 +460,58 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad Returns: pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. 
""" + def bytes_to_gb(bytes): - return bytes / 1024 ** 2 + return bytes / 1024**3 + from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) self.dataset_edge_index = to_device_safe(dataset.edge_index, device) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) - + # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." 
+ ) - if dataset.variable_types[target_var] == 'numerical': + if dataset.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique(dataset.ann[target_var])) - + # Report memory usage based on device type - if device.type == 'cuda': - print("Memory before batch processing: {:.3f} MB".format(bytes_to_gb(torch.cuda.max_memory_reserved()))) + if device.type == "cuda": + print( + "Memory before batch processing: {:.3f} MB".format( + bytes_to_gb(torch.cuda.max_memory_reserved()) + ) + ) # MPS and CPU don't have detailed memory tracking like CUDA aggregated_attributions = [[] for _ in range(num_class)] for batch in dataloader: x, y_dict, samples = batch - + # Ensure input data is on the correct device x = to_device_safe(x, device) input_data = x.unsqueeze(0).requires_grad_() @@ -439,38 +520,59 @@ def bytes_to_gb(bytes): x = to_device_safe(x, device) input_data = x.unsqueeze(0).requires_grad_() baseline = torch.zeros_like(input_data) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = torch.zeros_like(input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap - baseline = torch.cat([torch.zeros_like(input_data) for _ in range(steps_or_samples)], dim=0) - + elif method == "GradientShap": # provide multiple baselines for Gr.Shap + baseline = torch.cat( + [torch.zeros_like(input_data) for _ in range(steps_or_samples)], + dim=0, + ) + if num_class == 1: # returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + 
additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) - # For each target class concatenate node attributions accross batches - processed_attributions = [] + # For each target class concatenate node attributions accross batches + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -478,37 +580,57 @@ def bytes_to_gb(bytes): attr_concat = torch.cat([batch_attr for batch_attr in class_attr], dim=1) processed_attributions.append(attr_concat) - # compute absolute importance and move to cpu - abs_attr = [torch.abs(attr_class).cpu() for attr_class in processed_attributions] - # average over samples + # compute absolute importance and move to cpu + abs_attr = [ + 
torch.abs(attr_class).cpu() for attr_class in processed_attributions + ] + # average over samples imp = [a.mean(dim=1) for a in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') - + self.to("cpu") + # Report memory usage based on device type - if device.type == 'cuda': - print("Memory after batch processing: {:.3f} MB".format(bytes_to_gb(torch.cuda.max_memory_reserved()))) + if device.type == "cuda": + print( + "Memory after batch processing: {:.3f} MB".format( + bytes_to_gb(torch.cuda.max_memory_reserved()) + ) + ) # MPS and CPU don't have detailed memory tracking like CUDA df_list = [] - layers = list(getattr(dataset, 'multiomic_dataset', dataset).dat.keys()) + layers = list(getattr(dataset, "multiomic_dataset", dataset).dat.keys()) for i in range(num_class): features = dataset.common_features - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - for l in range(len(layers)): + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + for layer_idx, layer_name in enumerate(layers): # Extracting node feature attributes coming from different omic layers importances_array = imp[i].squeeze().detach().numpy() if importances_array.ndim == 1: - importances = importances_array # Use the array as is if it is 1-dimensional + importances = ( + importances_array # Use the array as is if it is 1-dimensional + ) else: - importances = importances_array[:, l] # Use the original indexing for 2D arrays - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[l], - 'name': features, - 'importance': importances})) + importances = importances_array[ + :, layer_idx + ] # Use the original indexing for 2D arrays + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": 
target_class_label, + "layer": layer_name, + "name": features, + "importance": importances, + } + ) + ) df_imp = pd.concat(df_list, ignore_index=True) # save the computed scores in the model - self.feature_importances[target_var] = df_imp \ No newline at end of file + self.feature_importances[target_var] = df_imp diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index 57609e9c..24f45981 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -1,24 +1,21 @@ # Supervised VAE-MMD architecture +import itertools + +import lightning as pl +import numpy as np +import pandas as pd import torch -import itertools +from captum.attr import GradientShap, IntegratedGradients from torch import nn from torch.nn import functional as F -from torch.utils.data import Dataset, DataLoader, random_split +from torch.utils.data import DataLoader -import pandas as pd -import numpy as np - -import lightning as pl -from scipy import stats +from ..modules import MLP, Decoder, Encoder, cox_ph_loss -from captum.attr import IntegratedGradients, GradientShap -from ..utils import to_device_safe -from ..modules import * - -# Supervised Variational Auto-encoder that can train one or more layers of omics datasets +# Supervised Variational Auto-encoder that can train one or more layers of omics datasets # num_layers: number of omics layers in the input -# each layer is encoded separately, encodings are concatenated, and decoded separately +# each layer is encoded separately, encodings are concatenated, and decoded separately # depends on MLP, Encoder, Decoder classes in models_shared class supervised_vae(pl.LightningModule): """ @@ -28,7 +25,7 @@ class supervised_vae(pl.LightningModule): Each omics layer is encoded separately using an Encoder. The resulting latent representations are then concatenated and passed through a fully connected network (fusion layer) to make predictions. 
The model also can be attached to one ore more supervisor heads for regression/classification/survival tasks. - In the absence of supervisor heads, it can be used for unsupervised learning. + In the absence of supervisor heads, it can be used for unsupervised learning. Attributes: config (dict): Configuration settings for the model, including learning rates and dimensions. @@ -40,9 +37,18 @@ class supervised_vae(pl.LightningModule): use_loss_weighting (bool, optional): Whether to use loss weighting in the model. Defaults to True. device_type (str, optional): Type of device to run the model ('gpu' or 'cpu'). Defaults to None. """ - def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, use_loss_weighting = True, - device_type = None): + + def __init__( + self, + config, + dataset, + target_variables, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + use_loss_weighting=True, + device_type=None, + ): super(supervised_vae, self).__init__() self.config = config self.dataset = dataset @@ -54,50 +60,74 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables + self.variables = ( + self.target_variables + batch_variables + if batch_variables + else self.target_variables + ) self.feature_importances = {} - + # sometimes the model may have exploding/vanishing gradients leading to NaN values - self.nan_detected = False + self.nan_detected = False self.device_type = device_type self.use_loss_weighting = use_loss_weighting - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() - for loss_type in 
itertools.chain(self.variables, ['mmd_loss']): + for loss_type in itertools.chain(self.variables, ["mmd_loss"]): self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) - + layers = list(dataset.dat.keys()) input_dims = [len(dataset.features[layers[i]]) for i in range(len(layers))] # create a list of Encoder instances for separately encoding each omics layer - self.encoders = nn.ModuleList([Encoder(input_dims[i], - # define hidden_dim size as a factor of input_dim; keep at least 2 units - [max(int(input_dims[i] * config['hidden_dim_factor']), 2)], - config['latent_dim']) for i in range(len(layers))]) + self.encoders = nn.ModuleList( + [ + Encoder( + input_dims[i], + # define hidden_dim size as a factor of input_dim; keep at least 2 units + [max(int(input_dims[i] * config["hidden_dim_factor"]), 2)], + config["latent_dim"], + ) + for i in range(len(layers)) + ] + ) # Fully connected layers for concatenated means and log_vars - self.FC_mean = nn.Linear(len(layers) * config['latent_dim'], config['latent_dim']) - self.FC_log_var = nn.Linear(len(layers) * config['latent_dim'], config['latent_dim']) - # list of decoders to decode each omics layer separately - self.decoders = nn.ModuleList([Decoder(config['latent_dim'], - # define hidden_dim size as a factor of input_dim; keep at least 2 units - [max(int(input_dims[i] * config['hidden_dim_factor']),2)], - input_dims[i]) for i in range(len(layers))]) + self.FC_mean = nn.Linear( + len(layers) * config["latent_dim"], config["latent_dim"] + ) + self.FC_log_var = nn.Linear( + len(layers) * config["latent_dim"], config["latent_dim"] + ) + # list of decoders to decode each omics layer separately + self.decoders = nn.ModuleList( + [ + Decoder( + config["latent_dim"], + # define hidden_dim size as a factor of input_dim; keep at least 2 units + [max(int(input_dims[i] * config["hidden_dim_factor"]), 2)], + input_dims[i], + ) + for i in range(len(layers)) + ] + ) # define supervisor heads # using ModuleDict to store multiple MLPs - 
self.MLPs = nn.ModuleDict() + self.MLPs = nn.ModuleDict() for var in self.variables: - if self.dataset.variable_types[var] == 'numerical': + if self.dataset.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.dataset.ann[var])) - self.MLPs[var] = MLP(input_dim = config['latent_dim'], - hidden_dim = config['supervisor_hidden_dim'], - output_dim = num_class) - + self.MLPs[var] = MLP( + input_dim=config["latent_dim"], + hidden_dim=config["supervisor_hidden_dim"], + output_dim=num_class, + ) + def multi_encoder(self, x_list): """ Encode each input matrix separately using the corresponding Encoder. @@ -122,7 +152,7 @@ def multi_encoder(self, x_list): mean = self.FC_mean(torch.cat(means, dim=1)) log_var = self.FC_log_var(torch.cat(log_vars, dim=1)) return mean, log_var - + def forward(self, x_list): """ Forward pass through the model. @@ -139,20 +169,20 @@ def forward(self, x_list): - y_pred (torch.Tensor): Predicted output. """ mean, log_var = self.multi_encoder(x_list) - + # generate latent layer z = self.reparameterization(mean, log_var) # Decode each latent variable with its corresponding Decoder x_hat_list = [self.decoders[i](z) for i in range(len(x_list))] - #run the supervisor heads using the latent layer as input + # run the supervisor heads using the latent layer as input outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(z) - + return x_hat_list, z, mean, log_var, outputs - + def reparameterization(self, mean, var): """ Reparameterize the mean and variance values. @@ -164,10 +194,10 @@ def reparameterization(self, mean, var): Returns: torch.Tensor: Latent representation. """ - epsilon = torch.randn_like(var) - z = mean + var*epsilon + epsilon = torch.randn_like(var) + z = mean + var * epsilon return z - + def configure_optimizers(self): """ Configure the optimizer for the model. @@ -175,63 +205,69 @@ def configure_optimizers(self): Returns: torch.optim.Adam: Adam optimizer with learning rate 1e-3. 
""" - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. 
""" - if self.dataset.variable_types[var] == 'numerical': + if self.dataset.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. 
If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -240,16 +276,18 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Executes one training step using a single batch from the training dataset. @@ -264,36 +302,39 @@ def training_step(self, train_batch, batch_idx, log = True): dat, y_dict, samples = train_batch layers = dat.keys() x_list = [dat[x] for x in layers] - + x_hat_list, z, mean, log_var, outputs = self.forward(x_list) - + # compute mmd loss for each layer and take average - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) for i in range(len(layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) + for i in range(len(layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} - + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} + for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = 
cox_ph_loss(risk_scores, durations, events) - else: + else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = self.compute_total_loss(losses) - # add total loss for logging - losses['train_loss'] = total_loss + # add total loss for logging + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ Executes one validation step using a single batch from the validation dataset. @@ -308,33 +349,36 @@ def validation_step(self, val_batch, batch_idx, log = True): dat, y_dict, samples = val_batch layers = dat.keys() x_list = [dat[x] for x in layers] - + x_hat_list, z, mean, log_var, outputs = self.forward(x_list) - + # compute mmd loss for each layer and take average - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) for i in range(len(layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) + for i in range(len(layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) - else: + else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - + def 
transform(self, dataset): """ Transform the input dataset to latent representation using batching. @@ -347,22 +391,37 @@ def transform(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed - all_latent_representations = [] # Initialize a list to collect all batch latent representations + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed + all_latent_representations = ( + [] + ) # Initialize a list to collect all batch latent representations sample_names = [] # List to collect sample names # Process each batch for batch in dataloader: dat, _, samples = batch - x_list = [dat[x].to(device) for x in dat.keys()] # Prepare the data batch for processing + x_list = [ + dat[x].to(device) for x in dat.keys() + ] # Prepare the data batch for processing # Perform the forward pass and extract the latent representation - latent_representation = self.forward(x_list)[1].detach().cpu().numpy() # Index [1] assumes second return is the latent rep + latent_representation = ( + self.forward(x_list)[1].detach().cpu().numpy() + ) # Index [1] assumes second return is the latent rep - all_latent_representations.append(latent_representation) # Store the batch's latent representation + all_latent_representations.append( + latent_representation + ) # Store the batch's latent representation sample_names.extend(samples) # Collect sample names for this batch # Concatenate all batch latent representations into one array @@ -370,11 +429,11 @@ def transform(self, dataset): # Convert the 
array to a DataFrame z = pd.DataFrame(concatenated_latents) - z.columns = ['E' + str(i) for i in range(z.shape[1])] # Name columns + z.columns = ["E" + str(i) for i in range(z.shape[1])] # Name columns z.index = sample_names # Set DataFrame index to sample names return z - + def predict(self, dataset): """ Evaluate the model on a dataset using batching. @@ -387,17 +446,28 @@ def predict(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Process each batch for batch in dataloader: dat, _, _ = batch - x_list = [dat[x].to(device) for x in dat.keys()] # Prepare the data batch for processing + x_list = [ + dat[x].to(device) for x in dat.keys() + ] # Prepare the data batch for processing # Perform the forward pass X_hat, z, mean, log_var, outputs = self.forward(x_list) @@ -405,17 +475,21 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 
predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} return predictions - + def compute_kernel(self, x, y): """ Compute the Gaussian kernel matrix between two sets of vectors. @@ -430,12 +504,12 @@ def compute_kernel(self, x, y): x_size = x.size(0) y_size = y.size(0) dim = x.size(1) - x = x.unsqueeze(1) # (x_size, 1, dim) - y = y.unsqueeze(0) # (1, y_size, dim) + x = x.unsqueeze(1) # (x_size, 1, dim) + y = y.unsqueeze(0) # (1, y_size, dim) tiled_x = x.expand(x_size, y_size, dim) tiled_y = y.expand(x_size, y_size, dim) - kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim) - return torch.exp(-kernel_input) # (x_size, y_size) + kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim) + return torch.exp(-kernel_input) # (x_size, y_size) def compute_mmd(self, x, y): """ @@ -451,7 +525,7 @@ def compute_mmd(self, x, y): x_kernel = self.compute_kernel(x, x) y_kernel = self.compute_kernel(y, y) xy_kernel = self.compute_kernel(x, y) - mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean() + mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean() return mmd def MMD_loss(self, latent_dim, z, xhat, x): @@ -467,27 +541,37 @@ def MMD_loss(self, latent_dim, z, xhat, x): Returns: torch.Tensor: A scalar tensor representing the MMD loss. """ - true_samples = torch.randn(200, latent_dim, device = self.device) - mmd = self.compute_mmd(true_samples, z) # compute maximum mean discrepancy (MMD) - nll = (xhat - x).pow(2).mean() #negative log likelihood - return mmd+nll - - # Adaptor forward function for captum integrated gradients. 
+ true_samples = torch.randn(200, latent_dim, device=self.device) + mmd = self.compute_mmd( + true_samples, z + ) # compute maximum mean discrepancy (MMD) + nll = (xhat - x).pow(2).mean() # negative log likelihood + return mmd + nll + + # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps for IntegratedGradients().attribute + steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] for i in range(steps): # get list of tensors for each step into a list of tensors x_step = [input_data[j][i] for j in range(len(input_data))] x_hat_list, z, mean, log_var, outputs = self.forward(x_step) outputs_list.append(outputs[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes feature importance for each variable in the dataset using either + Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. @@ -502,26 +586,40 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. 
""" from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) - - print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) + + print( + "[INFO] Computing feature importance for variable:", + target_var, + "on device:", + device, + ) # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." 
+ ) # Get the number of classes for the target variable - if self.dataset.variable_types[target_var] == 'numerical': + if self.dataset.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.dataset.ann[target_var])) @@ -529,48 +627,69 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) aggregated_attributions = [[] for _ in range(num_class)] - + for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( - torch.cat([torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0) + torch.cat( + [torch.zeros_like(x) for _ in range(steps_or_samples)], + dim=0, + ) for x in input_data ) - + if num_class == 1: - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in 
range(num_class): - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + target_var, + steps_or_samples, + ), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) - # For each target class and for each data modality/layer, concatenate attributions accross batches + # For each target class and for each data modality/layer, concatenate attributions accross batches layers = list(dataset.dat.keys()) num_layers = len(layers) - processed_attributions = [] + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -583,32 +702,44 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - + # summarize feature importances # Compute absolute attributions # Move the processed tensors to CPU for further operations that are not supported on GPU - abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in processed_attributions] - # average over samples + abs_attr = [ + [torch.abs(a).cpu() for a in attr_class] + for attr_class in processed_attributions + ] + 
# average over samples imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') + self.to("cpu") - # combine into a single data frame + # combine into a single data frame df_list = [] for i in range(num_class): for j in range(len(layers)): features = self.dataset.features[layers[j]] importances = imp[i][j][0].detach().numpy() - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, 'importance': importances})) - df_imp = pd.concat(df_list, ignore_index = True) - + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) + df_imp = pd.concat(df_list, ignore_index=True) + # save scores in model self.feature_importances[target_var] = df_imp - - diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index f714dcc9..4ba8b60b 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -1,22 +1,17 @@ -# Generating encodings of multi-omic data using triplet loss -import torch -from torch import nn -from torch.nn import functional as F -from torch.utils.data import Dataset, DataLoader, random_split - +# Generating encodings of multi-omic data using triplet loss import itertools -import pandas as pd -import numpy as np - import lightning as pl +import numpy as np +import pandas as pd +import torch +from captum.attr import GradientShap, IntegratedGradients +from torch import nn +from torch.nn import functional as F 
+from torch.utils.data import DataLoader -from ..utils import to_device_safe -from ..modules import * from ..data import TripletMultiOmicDataset - -from captum.attr import IntegratedGradients, GradientShap - +from ..modules import MLP, cox_ph_loss class MultiTripletNetwork(pl.LightningModule): @@ -37,11 +32,19 @@ class MultiTripletNetwork(pl.LightningModule): device_type (str, optional): Type of device ('gpu' or 'cpu') on which the model will be run. """ - def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, use_loss_weighting = True, - device_type = None): + def __init__( + self, + config, + dataset, + target_variables, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + use_loss_weighting=True, + device_type=None, + ): super(MultiTripletNetwork, self).__init__() - + self.config = config self.target_variables = target_variables self.surv_event_var = surv_event_var @@ -51,54 +54,73 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables + self.variables = ( + self.target_variables + batch_variables + if batch_variables + else self.target_variables + ) self.ann = dataset.ann self.variable_types = dataset.variable_types self.feature_importances = {} self.device_type = device_type - # The first target variable is the main variable that dictates the triplets - # it has to be a categorical variable - self.main_var = self.target_variables[0] - if self.variable_types[self.main_var] == 'numerical': - raise ValueError("The first target variable",self.main_var," must be a categorical variable") - + # The first target variable is the main variable that dictates the triplets + # it 
has to be a categorical variable + self.main_var = self.target_variables[0] + if self.variable_types[self.main_var] == "numerical": + raise ValueError( + "The first target variable", + self.main_var, + " must be a categorical variable", + ) + self.use_loss_weighting = use_loss_weighting - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() - for loss_type in itertools.chain(self.variables, ['triplet_loss']): + for loss_type in itertools.chain(self.variables, ["triplet_loss"]): self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) - + self.layers = list(dataset.dat.keys()) - self.input_dims = [len(dataset.features[self.layers[i]]) for i in range(len(self.layers))] - - self.encoders = nn.ModuleList([ - MLP(input_dim=self.input_dims[i], - # define hidden_dim size relative to the input_dim size - hidden_dim=int(self.input_dims[i] * self.config['hidden_dim_factor']), - output_dim=self.config['latent_dim']) for i in range(len(self.layers))]) - + self.input_dims = [ + len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) + ] + + self.encoders = nn.ModuleList( + [ + MLP( + input_dim=self.input_dims[i], + # define hidden_dim size relative to the input_dim size + hidden_dim=int( + self.input_dims[i] * self.config["hidden_dim_factor"] + ), + output_dim=self.config["latent_dim"], + ) + for i in range(len(self.layers)) + ] + ) + if len(self.input_dims) > 1: self.fusion_block = nn.Linear( - in_features=self.config['latent_dim'] * len(self.layers), - out_features=self.config['latent_dim'] + in_features=self.config["latent_dim"] * len(self.layers), + out_features=self.config["latent_dim"], ) else: self.fusion_block = None - - # define supervisor heads for both target and batch variables - self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs + # define supervisor heads for both target and batch variables + self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple 
MLPs for var in self.variables: - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[var])) - self.MLPs[var] = MLP(input_dim=self.config['latent_dim'], - hidden_dim=self.config['supervisor_hidden_dim'], - output_dim=num_class) - + self.MLPs[var] = MLP( + input_dim=self.config["latent_dim"], + hidden_dim=self.config["supervisor_hidden_dim"], + output_dim=num_class, + ) + def concat_embeddings(self, dat): embeddings_list = [] x_list = [dat[x] for x in dat.keys()] @@ -106,10 +128,14 @@ def concat_embeddings(self, dat): for i, x in enumerate(x_list): embeddings_list.append(self.encoders[i](x)) embeddings_concat = torch.cat(embeddings_list, dim=1) - # if multiple embeddings, fuse them - embeddings = self.fusion_block(embeddings_concat) if self.fusion_block else embeddings_concat + # if multiple embeddings, fuse them + embeddings = ( + self.fusion_block(embeddings_concat) + if self.fusion_block + else embeddings_concat + ) return embeddings - + def forward(self, anchor, positive, negative): """ Compute the forward pass of the MultiTripletNetwork and return the embeddings and predictions. @@ -126,13 +152,18 @@ def forward(self, anchor, positive, negative): anchor_embedding = self.concat_embeddings(anchor) positive_embedding = self.concat_embeddings(positive) negative_embedding = self.concat_embeddings(negative) - - #run the supervisor heads using the anchor embeddings as input + + # run the supervisor heads using the anchor embeddings as input outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(anchor_embedding) - return anchor_embedding, positive_embedding, negative_embedding, outputs - + return ( + anchor_embedding, + positive_embedding, + negative_embedding, + outputs, + ) + def configure_optimizers(self): """ Configure the optimizer for the MultiTripletNetwork. 
@@ -140,9 +171,9 @@ def configure_optimizers(self): Returns: torch.optim.Optimizer: The configured optimizer. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def triplet_loss(self, anchor, positive, negative, margin=1.0): """ Compute the triplet loss for the given anchor, positive, and negative embeddings. @@ -159,63 +190,69 @@ def triplet_loss(self, anchor, positive, negative, margin=1.0): distance_positive = (anchor - positive).pow(2).sum(1) distance_negative = (anchor - negative).pow(2).sum(1) losses = torch.relu(distance_positive - distance_negative + margin) - return losses.mean() + return losses.mean() def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. 
""" - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. 
If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -224,42 +261,59 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Perform a training step using a single batch of data, including triplet components and target labels. - + Args: - train_batch (tuple): The batch containing data tuples (anchor, positive, negative) and a dictionary of labels. + train_batch (tuple): The batch containing data tuples (anchor, positive, + negative) and a dictionary of labels. batch_idx (int): The index of the current batch. - log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. - + log (bool, optional): Flag to determine if logging should occur at each + step. Defaults to True. + Returns: - torch.Tensor: The total loss for the current training batch, which includes triplet loss and any additional - losses from supervisor heads. - - This method computes the embedding for the anchor, positive, and negative samples and calculates the triplet loss. - Additional losses are computed for other target variables in the dataset, particularly handling survival analysis - if applicable. All losses are combined to compute a total loss, which is logged and returned. 
+ torch.Tensor: The total loss for the current training batch, which + includes triplet loss and any additional losses from supervisor + heads. + + This method computes the embedding for the anchor, positive, and negative + samples and calculates the triplet loss. Additional losses are computed for + other target variables in the dataset, particularly handling survival + analysis if applicable. All losses are combined to compute a total loss, + which is logged and returned. """ - anchor, positive, negative, y_dict = train_batch[0], train_batch[1], train_batch[2], train_batch[3] - anchor_embedding, positive_embedding, negative_embedding, outputs = self.forward(anchor, positive, negative) - triplet_loss = self.triplet_loss(anchor_embedding, positive_embedding, negative_embedding) - - # compute loss values for the supervisor heads - losses = {'triplet_loss': triplet_loss} + anchor, positive, negative, y_dict = ( + train_batch[0], + train_batch[1], + train_batch[2], + train_batch[3], + ) + anchor_embedding, positive_embedding, negative_embedding, outputs = ( + self.forward(anchor, positive, negative) + ) + triplet_loss = self.triplet_loss( + anchor_embedding, positive_embedding, negative_embedding + ) + + # compute loss values for the supervisor heads + losses = {"triplet_loss": triplet_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] @@ -268,69 +322,83 @@ def training_step(self, train_batch, batch_idx, log = True): losses[var] = loss total_loss = self.compute_total_loss(losses) - # add total loss for logging - losses['train_loss'] = total_loss + # add total loss for logging + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, 
prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ Perform a validation step using a single batch of data, including triplet components and target labels. - + Args: - val_batch (tuple): The batch containing data tuples (anchor, positive, negative) and a dictionary of labels. + val_batch (tuple): The batch containing data tuples (anchor, positive, + negative) and a dictionary of labels. batch_idx (int): The index of the current batch. - log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. - + log (bool, optional): Flag to determine if logging should occur at each + step. Defaults to True. + Returns: - torch.Tensor: The total loss for the current validation batch, which includes triplet loss and any additional - losses from supervisor heads. - - Similar to the training step, this method computes the embedding for the anchor, positive, and negative samples - and calculates the triplet loss. It computes additional losses for other target variables in the dataset, aggregates - all losses, and returns the total loss. The losses are logged if specified. + torch.Tensor: The total loss for the current validation batch, which + includes triplet loss and any additional losses from supervisor + heads. + + Similar to the training step, this method computes the embedding for the + anchor, positive, and negative samples and calculates the triplet loss. It + computes additional losses for other target variables in the dataset, + aggregates all losses, and returns the total loss. The losses are logged if + specified. 
""" - anchor, positive, negative, y_dict = val_batch[0], val_batch[1], val_batch[2], val_batch[3] - anchor_embedding, positive_embedding, negative_embedding, outputs = self.forward(anchor, positive, negative) - triplet_loss = self.triplet_loss(anchor_embedding, positive_embedding, negative_embedding) - - # compute loss values for the supervisor heads - losses = {'triplet_loss': triplet_loss} + anchor, positive, negative, y_dict = ( + val_batch[0], + val_batch[1], + val_batch[2], + val_batch[3], + ) + anchor_embedding, positive_embedding, negative_embedding, outputs = ( + self.forward(anchor, positive, negative) + ) + triplet_loss = self.triplet_loss( + anchor_embedding, positive_embedding, negative_embedding + ) + + # compute loss values for the supervisor heads + losses = {"triplet_loss": triplet_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - + # dataset: MultiOmicDataset def transform(self, dataset): """ Transforms the input dataset by generating embeddings and predictions. - + Args: dataset (MultiOmicDataset): An instance of the MultiOmicDataset class. - + Returns: z (pd.DataFrame): A dataframe containing the computed embeddings. y_pred (np.ndarray): A numpy array containing the predicted labels. 
""" self.eval() - # get anchor embeddings + # get anchor embeddings z = pd.DataFrame(self.concat_embeddings(dataset.dat).detach().numpy()) - z.columns = [''.join(['E', str(x)]) for x in z.columns] + z.columns = ["".join(["E", str(x)]) for x in z.columns] z.index = dataset.samples return z @@ -342,7 +410,8 @@ def predict(self, dataset): dataset: The dataset to evaluate the model on. Returns: - A dictionary where each key is a target variable and the corresponding value is the predicted output for that variable. + A dictionary where each key is a target variable and the corresponding + value is the predicted output for that variable. """ self.eval() # get anchor embedding @@ -351,46 +420,64 @@ def predict(self, dataset): outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(anchor_embedding) - + # get predictions from the mlp outputs for each var - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} - + return predictions - - # Adaptor forward function for captum integrated gradients. 
- # layer_sizes: number of features in each omic layer + # Adaptor forward function for captum integrated gradients. + # layer_sizes: number of features in each omic layer def forward_target(self, input_data, layer_sizes, target_var, steps): outputs_list = [] for i in range(steps): - # for each step, get anchor/positive/negative tensors + # for each step, get anchor/positive/negative tensors # (split the concatenated omics layers) - anchor = input_data[i][0].split(layer_sizes, dim = 1) - positive = input_data[i][1].split(layer_sizes, dim = 1) - negative = input_data[i][2].split(layer_sizes, dim = 1) - + anchor = input_data[i][0].split(layer_sizes, dim=1) + positive = input_data[i][1].split(layer_sizes, dim=1) + negative = input_data[i][2].split(layer_sizes, dim=1) + # convert to dict anchor = {k: anchor[k] for k in range(len(anchor))} - positive = {k: anchor[k] for k in range(len(positive))} - negative = {k: anchor[k] for k in range(len(negative))} - anchor_embedding, positive_embedding, negative_embedding, outputs = self.forward(anchor, positive, negative) + positive = {k: positive[k] for k in range(len(positive))} + negative = {k: negative[k] for k in range(len(negative))} + ( + anchor_embedding, + positive_embedding, + negative_embedding, + outputs, + ) = self.forward(anchor, positive, negative) outputs_list.append(outputs[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes the feature importance for each variable in the dataset using + either Integrated Gradients or Gradient SHAP. 
Args: dataset: The dataset object containing the features and data. @@ -406,94 +493,147 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad """ from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) - print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) + print( + "[INFO] Computing feature importance for variable:", + target_var, + "on device:", + device, + ) - # define data loader + # define data loader triplet_dataset = TripletMultiOmicDataset(dataset, self.main_var) dataloader = DataLoader(triplet_dataset, batch_size=batch_size, shuffle=False) - + # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." 
+ ) - if self.variable_types[target_var] == 'numerical': + if self.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique([y[target_var] for _, y, _ in dataset])) aggregated_attributions = [[] for _ in range(num_class)] for batch in dataloader: - # see training_step to see how elements are accessed in batches - anchor, positive, negative, y_dict = batch[0], batch[1], batch[2], batch[3] + # see training_step to see how elements are accessed in batches + anchor, positive, negative = ( + batch[0], + batch[1], + batch[2], + ) # Move tensors to the specified device using MPS-safe method anchor = {k: to_device_safe(v, device) for k, v in anchor.items()} positive = {k: to_device_safe(v, device) for k, v in positive.items()} negative = {k: to_device_safe(v, device) for k, v in negative.items()} - + anchor = [data.requires_grad_() for data in list(anchor.values())] positive = [data.requires_grad_() for data in list(positive.values())] negative = [data.requires_grad_() for data in list(negative.values())] - + # concatenate multiomic layers of each list element - # then stack the anchor/positive/negative + # then stack the anchor/positive/negative # the purpose is to get a single tensor - input_data = torch.stack([torch.cat(sublist, dim = 1) for sublist in [anchor, positive, negative]]).unsqueeze(0) + input_data = torch.stack( + [torch.cat(sublist, dim=1) for sublist in [anchor, positive, negative]] + ).unsqueeze(0) - # layer sizes will be needed to revert the concatenated tensor + # layer sizes will be needed to revert the concatenated tensor # anchor/positive/negative have the same shape - layer_sizes = [anchor[i].shape[1] for i in range(len(anchor))] - - # Define a baseline - if method == 'IntegratedGradients': - baseline = torch.zeros_like(input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap - baseline = torch.cat([torch.zeros_like(input_data) for _ in range(steps_or_samples)], dim=0) - + layer_sizes = 
[anchor[i].shape[1] for i in range(len(anchor))] + + # Define a baseline + if method == "IntegratedGradients": + baseline = torch.zeros_like(input_data) + elif method == "GradientShap": # provide multiple baselines for Gr.Shap + baseline = torch.cat( + [torch.zeros_like(input_data) for _ in range(steps_or_samples)], + dim=0, + ) + if num_class == 1: - - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - n_steps=steps_or_samples) - attributions = attributions.split(layer_sizes, dim = 3) - elif method == 'GradientShape': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - n_samples=steps_or_samples) - attributions = attributions.split(layer_sizes, dim = 3) + + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + n_steps=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) + elif method == "GradientShape": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + n_samples=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - target=target_class, n_steps=steps_or_samples) - attributions = attributions.split(layer_sizes, dim = 3) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - target=target_class, n_samples=steps_or_samples) - attributions 
= attributions.split(layer_sizes, dim = 3) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + target=target_class, + n_steps=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + target=target_class, + n_samples=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) aggregated_attributions[target_class].append(attributions) - # For each target class and for each data modality/layer, concatenate attributions accross batches + # For each target class and for each data modality/layer, concatenate attributions accross batches layers = self.layers num_layers = len(layers) - processed_attributions = [] + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -507,25 +647,40 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - # compute absolute importance and move to cpu + # compute absolute importance and move to cpu # notice the squeeze (due to triplets) - abs_attr = [[torch.abs(a.squeeze()).cpu() for a in attr_class] for attr_class in processed_attributions] - # average over samples + abs_attr = [ + [torch.abs(a.squeeze()).cpu() for a in attr_class] + for attr_class in processed_attributions + ] + # average over samples imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') + self.to("cpu") df_list = [] for i in range(num_class): for j in range(len(layers)): features = dataset.features[layers[j]] # Ensure 
tensors are already on CPU before converting to numpy - importances = imp[i][j][0].detach().numpy() # 0 => extract importances only for the anchor - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, - 'importance': importances})) + importances = ( + imp[i][j][0].detach().numpy() + ) # 0 => extract importances only for the anchor + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) df_imp = pd.concat(df_list, ignore_index=True) self.feature_importances[target_var] = df_imp diff --git a/flexynesis/modules.py b/flexynesis/modules.py index aed8ac83..c1d2722d 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -2,8 +2,7 @@ import torch from torch import nn -from torch_geometric.nn import aggr, GCNConv, GATConv, SAGEConv, GraphConv - +from torch_geometric.nn import GATConv, GCNConv, GraphConv, SAGEConv __all__ = ["Encoder", "Decoder", "MLP", "flexGCN", "cox_ph_loss"] @@ -11,59 +10,61 @@ class Encoder(nn.Module): """ Encoder class for a Variational Autoencoder (VAE). - + The Encoder class is responsible for taking input data and generating the mean and log variance for the latent space representation. 
""" + def __init__(self, input_dim, hidden_dims, latent_dim): super(Encoder, self).__init__() self.act = nn.LeakyReLU(0.2) - + hidden_layers = [] - + hidden_layers.append(nn.Linear(input_dim, hidden_dims[0])) nn.init.xavier_uniform_(hidden_layers[-1].weight) hidden_layers.append(self.act) hidden_layers.append(nn.BatchNorm1d(hidden_dims[0])) - for i in range(len(hidden_dims)-1): - hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) + for i in range(len(hidden_dims) - 1): + hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1])) nn.init.xavier_uniform_(hidden_layers[-1].weight) hidden_layers.append(self.act) - hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1])) + hidden_layers.append(nn.BatchNorm1d(hidden_dims[i + 1])) self.hidden_layers = nn.Sequential(*hidden_layers) - - self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim) + + self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim) nn.init.xavier_uniform_(self.FC_mean.weight) - self.FC_var = nn.Linear(hidden_dims[-1], latent_dim) + self.FC_var = nn.Linear(hidden_dims[-1], latent_dim) nn.init.xavier_uniform_(self.FC_var.weight) - + def forward(self, x): """ Performs a forward pass through the Encoder network. - + Args: x (torch.Tensor): The input data tensor. - + Returns: mean (torch.Tensor): The mean of the latent space representation. log_var (torch.Tensor): The log variance of the latent space representation. """ - h_ = self.hidden_layers(x) - mean = self.FC_mean(h_) - log_var = self.FC_var(h_) + h_ = self.hidden_layers(x) + mean = self.FC_mean(h_) + log_var = self.FC_var(h_) return mean, log_var - - + + class Decoder(nn.Module): """ Decoder class for a Variational Autoencoder (VAE). - + The Decoder class is responsible for taking the latent space representation and generating the reconstructed output data. 
""" + def __init__(self, latent_dim, hidden_dims, output_dim): super(Decoder, self).__init__() @@ -80,7 +81,7 @@ def __init__(self, latent_dim, hidden_dims, output_dim): hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1])) nn.init.xavier_uniform_(hidden_layers[-1].weight) hidden_layers.append(self.act) - hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1])) + hidden_layers.append(nn.BatchNorm1d(hidden_dims[i + 1])) self.hidden_layers = nn.Sequential(*hidden_layers) @@ -90,49 +91,54 @@ def __init__(self, latent_dim, hidden_dims, output_dim): def forward(self, x): """ Performs a forward pass through the Decoder network. - + Args: x (torch.Tensor): The input tensor representing the latent space. - + Returns: x_hat (torch.Tensor): The reconstructed output tensor. """ h = self.hidden_layers(x) x_hat = torch.sigmoid(self.FC_output(h)) return x_hat - + class MLP(nn.Module): """ A Multi-Layer Perceptron (MLP) model for regression or classification tasks. - + The MLP class is a simple feed-forward neural network that can be used for regression when `output_dim` is set to 1 or for classification when `output_dim` is greater than 1. """ + def __init__(self, input_dim, hidden_dim, output_dim): """ Initializes the MLP class with the given input dimension, output dimension, and hidden layer size. - + Args: input_dim (int): The input dimension. hidden_dim (int, optional): The size of the hidden layer. Default is 32. output_dim (int): The output dimension. Set to 1 for regression tasks, and > 1 for classification tasks. 
""" super().__init__() - hidden_dim = max(hidden_dim, 2) # make sure there are at least 2 units + hidden_dim = max(hidden_dim, 2) # make sure there are at least 2 units self.layer_1 = nn.Linear(input_dim, hidden_dim) - self.layer_out = nn.Linear(hidden_dim, output_dim) if output_dim > 1 else nn.Linear(hidden_dim, 1, bias=False) - self.relu = nn.ReLU() + self.layer_out = ( + nn.Linear(hidden_dim, output_dim) + if output_dim > 1 + else nn.Linear(hidden_dim, 1, bias=False) + ) + self.relu = nn.ReLU() self.dropout = nn.Dropout(p=0.1) self.batchnorm = nn.BatchNorm1d(hidden_dim) def forward(self, x): """ Performs a forward pass through the MLP network. - + Args: x (torch.Tensor): The input data tensor. - + Returns: x (torch.Tensor): The output tensor after passing through the MLP network. """ @@ -142,7 +148,8 @@ def forward(self, x): x = self.dropout(x) x = self.layer_out(x) return x - + + class flexGCN(nn.Module): """ A Graph Neural Network (GNN) model using configurable convolution and activation layers. @@ -166,57 +173,78 @@ class flexGCN(nn.Module): output_dim (int): The size of the output vector, which is the final feature vector for the whole graph. num_convs (int, optional): Number of convolutional layers in the network. Defaults to 2. dropout_rate (float, optional): The dropout probability used for regularization. Defaults to 0.2. - conv (str, optional): Type of convolution layer to use. Supported types include 'GCN' for Graph Convolution Network, - 'GAT' for Graph Attention Network, 'SAGE' for GraphSAGE, and 'GC' for generic Graph Convolution. + conv (str, optional): Type of convolution layer to use. Supported types + include 'GCN' for Graph Convolution Network, + 'GAT' for Graph Attention Network, 'SAGE' for + GraphSAGE, and 'GC' for generic Graph Convolution. Defaults to 'GC'. - act (str, optional): Type of activation function to use. Supported types include 'relu', 'sigmoid', - 'leakyrelu', 'tanh', and 'gelu'. Defaults to 'relu'. 
+ act (str, optional): Type of activation function to use. Supported types + include 'relu', 'sigmoid', 'leakyrelu', 'tanh', + and 'gelu'. Defaults to 'relu'. Raises: ValueError: If an unsupported activation function or convolution type is specified. Example: - >>> model = flexGCN(node_count=100, node_feature_count=5, node_embedding_dim=64, output_dim=10, + >>> model = flexGCN(node_count=100, node_feature_count=5, node_embedding_dim=64, output_dim=10, num_convs=3, dropout_rate=0.3, conv='GAT', act='relu') >>> output = model(input_features, edge_index) # Where `input_features` is a tensor of shape (batch_size, num_nodes, node_feature_count) # and `edge_index` is a list of edges in the COO format (2, num_edges). """ - def __init__(self, node_count, node_feature_count, node_embedding_dim, output_dim, - num_convs = 2, dropout_rate = 0.2, conv='GC', act='relu'): + + def __init__( + self, + node_count, + node_feature_count, + node_embedding_dim, + output_dim, + num_convs=2, + dropout_rate=0.2, + conv="GC", + act="relu", + ): super().__init__() act_options = { - 'relu': nn.ReLU(), - 'sigmoid': nn.Sigmoid(), - 'leakyrelu': nn.LeakyReLU(), - 'tanh': nn.Tanh(), - 'gelu': nn.GELU() + "relu": nn.ReLU(), + "sigmoid": nn.Sigmoid(), + "leakyrelu": nn.LeakyReLU(), + "tanh": nn.Tanh(), + "gelu": nn.GELU(), } if act not in act_options: - raise ValueError("Invalid activation function string. Choose from ", list(act_options.keys())) - + raise ValueError( + "Invalid activation function string. Choose from ", + list(act_options.keys()), + ) + conv_options = { - 'GCN': GCNConv, - 'GAT': GATConv, - 'SAGE': SAGEConv, - 'GC': GraphConv + "GCN": GCNConv, + "GAT": GATConv, + "SAGE": SAGEConv, + "GC": GraphConv, } if conv not in conv_options: - raise ValueError('Unknown convolution type. Choose one of: ', list(conv_options.keys())) + raise ValueError( + "Unknown convolution type. 
Choose one of: ", + list(conv_options.keys()), + ) self.act = act_options[act] self.convs = nn.ModuleList() self.bns = nn.ModuleList() self.dropout = nn.Dropout(dropout_rate) - + # Initialize the first convolution layer separately if different input size self.convs.append(conv_options[conv](node_feature_count, node_embedding_dim)) self.bns.append(nn.BatchNorm1d(node_embedding_dim)) # Loop to create the remaining convolution and BN layers for _ in range(1, num_convs): - self.convs.append(conv_options[conv](node_embedding_dim, node_embedding_dim)) + self.convs.append( + conv_options[conv](node_embedding_dim, node_embedding_dim) + ) self.bns.append(nn.BatchNorm1d(node_embedding_dim)) # Final fully connected layer @@ -227,13 +255,14 @@ def forward(self, x, edge_index): x = conv(x, edge_index) x = bn(x.view(-1, x.size(2))).view_as(x) x = self.act(x) - x = self.dropout(x) + x = self.dropout(x) # Flatten the output of all nodes into a single vector per graph/sample x = x.view(x.size(0), -1) x = self.fc(x) return x + def cox_ph_loss(outputs, durations, events): """ Calculate the Cox proportional hazards loss. 
@@ -251,22 +280,26 @@ def cox_ph_loss(outputs, durations, events): outputs = outputs[valid_indices] events = events[valid_indices] durations = durations[valid_indices] - + # Exponentiate the outputs to get the hazard ratios hazards = torch.exp(outputs) # Ensure hazards is at least 1D if hazards.dim() == 0: hazards = hazards.unsqueeze(0) # Make hazards 1D if it's a scalar # Calculate the risk set sum - log_risk_set_sum = torch.log(torch.cumsum(hazards[torch.argsort(durations, descending=True)], dim=0)) + log_risk_set_sum = torch.log( + torch.cumsum(hazards[torch.argsort(durations, descending=True)], dim=0) + ) # Get the indices that sort the durations in descending order sorted_indices = torch.argsort(durations, descending=True) events_sorted = events[sorted_indices] # Calculate the loss - uncensored_loss = torch.sum(outputs[sorted_indices][events_sorted == 1]) - torch.sum(log_risk_set_sum[events_sorted == 1]) + uncensored_loss = torch.sum( + outputs[sorted_indices][events_sorted == 1] + ) - torch.sum(log_risk_set_sum[events_sorted == 1]) total_loss = -uncensored_loss / torch.sum(events) - else: + else: total_loss = torch.tensor(0.0, device=outputs.device, requires_grad=True) if not torch.isfinite(total_loss): return torch.tensor(0.0, device=outputs.device, requires_grad=True) diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 901ac8ab..c08a76cb 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -1,37 +1,31 @@ -from lightning import seed_everything -import pandas as pd -import numpy as np -import torch -import math -import warnings -import requests -import tarfile import os -from glob import glob -import re -import logging -from tqdm import tqdm +import tarfile +import warnings +from collections import deque -from umap import UMAP -import seaborn as sns import matplotlib.pyplot as plt -import matplotlib +import numpy as np +import pandas as pd +import requests +import seaborn as sns +import torch +from scipy.stats import kruskal, linregress, 
mannwhitneyu, pearsonr from sklearn.decomposition import PCA -from sklearn.metrics import balanced_accuracy_score, f1_score, cohen_kappa_score, classification_report, roc_auc_score, average_precision_score -from sklearn.metrics import mean_squared_error -from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score -from sklearn.utils import resample -from sksurv.metrics import cumulative_dynamic_auc -from sksurv.util import Surv - -from scipy.stats import pearsonr, linregress - - from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.feature_selection import (SelectFromModel, mutual_info_classif, + mutual_info_regression) +from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score, + average_precision_score, balanced_accuracy_score, + classification_report, cohen_kappa_score, + f1_score, mean_squared_error, + precision_recall_curve, roc_auc_score, roc_curve) +from sklearn.model_selection import GridSearchCV, KFold +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import label_binarize from sklearn.svm import SVC, SVR -from sklearn.feature_selection import SelectFromModel -from sklearn.feature_selection import mutual_info_regression, mutual_info_classif -from sklearn.model_selection import KFold, cross_val_score, GridSearchCV +from sksurv.metrics import cumulative_dynamic_auc +from sksurv.util import Surv +from umap import UMAP try: from xgboost import XGBClassifier, XGBRegressor @@ -39,44 +33,27 @@ XGBClassifier = None XGBRegressor = None -from sksurv.ensemble import RandomSurvivalForest -from sksurv.metrics import concordance_index_censored - -from lifelines import KaplanMeierFitter -from lifelines.utils import concordance_index -from lifelines import CoxPHFitter +import community as community_louvain +import networkx as nx +import ot +from lifelines import CoxPHFitter, KaplanMeierFitter from lifelines.statistics import logrank_test, multivariate_logrank_test - -from 
plotnine import ( - ggplot, aes, geom_point, geom_smooth, geom_line, geom_abline, geom_step, - labs, ggtitle, annotate, theme_minimal, theme, element_text, - scale_color_manual, scale_color_gradient, scale_color_brewer, - geom_errorbarh, geom_text, - theme_bw, theme, element_blank, scale_y_discrete -) - +from lifelines.utils import concordance_index +from plotnine import (aes, annotate, element_text, geom_abline, geom_errorbarh, + geom_line, geom_point, geom_smooth, geom_step, geom_text, + ggplot, ggtitle, labs, scale_color_gradient, + scale_color_manual, theme, theme_bw, theme_minimal) from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score from sklearn.metrics.pairwise import euclidean_distances -import networkx as nx -import community as community_louvain - from sklearn.preprocessing import StandardScaler -import ot - +from sksurv.ensemble import RandomSurvivalForest +from sksurv.metrics import concordance_index_censored -# imports -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -from plotnine import ( - ggplot, aes, geom_point, geom_step, labs, ggtitle, annotate, - theme_minimal, theme, element_text, scale_color_manual, scale_color_gradient -) -from sklearn.decomposition import PCA -from umap import UMAP -from lifelines import KaplanMeierFitter -from lifelines.statistics import logrank_test, multivariate_logrank_test +try: + from geomloss import SamplesLoss +except Exception: + SamplesLoss = None def _labels_to_1d(labels): @@ -91,6 +68,7 @@ def _labels_to_1d(labels): arr = np.asarray(labels) return np.ravel(arr) + def get_color_mapping(labels): """ Map categorical labels to colors using ALPHABETICAL order (deterministic). 
@@ -102,17 +80,17 @@ def get_color_mapping(labels): n = len(unique_labels) if n <= 10: - cmap = plt.get_cmap('tab10', n) + cmap = plt.get_cmap("tab10", n) colors = [cmap(i) for i in range(n)] elif n <= 20: - cmap = plt.get_cmap('tab20', n) + cmap = plt.get_cmap("tab20", n) colors = [cmap(i) for i in range(n)] else: # stack multiple palettes to cover many categories palettes = [ - plt.get_cmap('tab20', 20), - plt.get_cmap('Dark2', 8), - plt.get_cmap('Accent', 8), + plt.get_cmap("tab20", 20), + plt.get_cmap("Dark2", 8), + plt.get_cmap("Accent", 8), ] colors = [] for pal in palettes: @@ -122,25 +100,33 @@ def get_color_mapping(labels): colors.extend(colors) colors = colors[:n] - to_hex = lambda c: '#%02x%02x%02x' % (int(c[0]*255), int(c[1]*255), int(c[2]*255)) + def to_hex(c): + return "#%02x%02x%02x" % ( + int(c[0] * 255), + int(c[1] * 255), + int(c[2] * 255), + ) + color_hex = [to_hex(c) for c in colors] return dict(zip(unique_labels, color_hex)) -def plot_dim_reduced(matrix, labels, method='pca', color_type='categorical', title=None): +def plot_dim_reduced( + matrix, labels, method="pca", color_type="categorical", title=None +): """ Plot first two dims (PCA/UMAP). Uses alphabetical label ordering + shared palette. """ method = method.lower() - if method == 'pca': + if method == "pca": transformer = PCA(n_components=2) transformed = transformer.fit_transform(matrix) var_exp = transformer.explained_variance_ratio_ * 100 xlab = f"PC1 ({var_exp[0]:.1f}%)" ylab = f"PC2 ({var_exp[1]:.1f}%)" - xcol, ycol = 'PC1', 'PC2' - elif method == 'umap': + xcol, ycol = "PC1", "PC2" + elif method == "umap": transformer = UMAP(n_components=2) m = np.array(matrix, dtype=np.float32) try: @@ -148,7 +134,7 @@ def plot_dim_reduced(matrix, labels, method='pca', color_type='categorical', tit except TypeError: transformed = transformer.fit_transform(m) xlab, ylab = "UMAP1", "UMAP2" - xcol, ycol = 'UMAP1', 'UMAP2' + xcol, ycol = "UMAP1", "UMAP2" else: raise ValueError("Invalid method. 
Expected 'pca' or 'umap'.") @@ -162,20 +148,20 @@ def plot_dim_reduced(matrix, labels, method='pca', color_type='categorical', tit plot_title = title if title else f"{method.upper()} Scatter Plot" - if color_type == 'categorical': + if color_type == "categorical": p = ( - ggplot(df, aes(x=xcol, y=ycol, color='Label')) + ggplot(df, aes(x=xcol, y=ycol, color="Label")) + geom_point() + scale_color_manual(values=color_mapping) + labs(title=plot_title, x=xlab, y=ylab, color="Labels") + theme_minimal() ) - elif color_type == 'numerical': + elif color_type == "numerical": # numerical coloring ignores the categorical palette df_num = df.copy() - df_num["Label"] = pd.to_numeric(lbls, errors='coerce') + df_num["Label"] = pd.to_numeric(lbls, errors="coerce") p = ( - ggplot(df_num, aes(x=xcol, y=ycol, color='Label')) + ggplot(df_num, aes(x=xcol, y=ycol, color="Label")) + geom_point() + scale_color_gradient(low="blue", high="red") + labs(title=plot_title, x=xlab, y=ylab, color="Label") @@ -191,11 +177,15 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable): """ Kaplan–Meier curves with alphabetical label ordering + shared palette. 
""" - data = pd.DataFrame({ - 'Duration': _labels_to_1d(durations), - 'Event': _labels_to_1d(events), - 'Group': pd.Series(_labels_to_1d(categorical_variable), dtype="object").astype(str) - }) + data = pd.DataFrame( + { + "Duration": _labels_to_1d(durations), + "Event": _labels_to_1d(events), + "Group": pd.Series( + _labels_to_1d(categorical_variable), dtype="object" + ).astype(str), + } + ) # shared palette + fixed legend order color_mapping = get_color_mapping(data["Group"]) @@ -205,49 +195,58 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable): kmf = KaplanMeierFitter() survival_curves = [] for g in order: # iterate in the same alphabetical order - gd = data[data['Group'] == g] + gd = data[data["Group"] == g] if len(gd) == 0: continue - kmf.fit(gd['Duration'], gd['Event'], label=g) + kmf.fit(gd["Duration"], gd["Event"], label=g) surv_df = kmf.survival_function_.reset_index() - surv_df.columns = ['Time', 'Survival'] - surv_df['Group'] = g + surv_df.columns = ["Time", "Survival"] + surv_df["Group"] = g survival_curves.append(surv_df) plot_data = pd.concat(survival_curves, ignore_index=True) - plot_data["Group"] = pd.Categorical(plot_data["Group"], categories=order, ordered=True) + plot_data["Group"] = pd.Categorical( + plot_data["Group"], categories=order, ordered=True + ) # log-rank text - categories = pd.unique(data['Group']) + categories = pd.unique(data["Group"]) if len(categories) == 2: g1, g2 = categories[0], categories[1] - grp1 = data[data['Group'] == g1] - grp2 = data[data['Group'] == g2] - result = logrank_test(grp1['Duration'], grp2['Duration'], - event_observed_A=grp1['Event'], - event_observed_B=grp2['Event']) + grp1 = data[data["Group"] == g1] + grp2 = data[data["Group"] == g2] + result = logrank_test( + grp1["Duration"], + grp2["Duration"], + event_observed_A=grp1["Event"], + event_observed_B=grp2["Event"], + ) p_text = f"Log-rank p = {result.p_value:.2e}" elif len(categories) > 2: - result = 
multivariate_logrank_test(data['Duration'], data['Group'], data['Event']) + result = multivariate_logrank_test( + data["Duration"], data["Group"], data["Event"] + ) p_text = f"Multivariate log-rank p = {result.p_value:.2e}" else: p_text = "Only one group — log-rank test not applicable" p = ( - ggplot(plot_data, aes(x='Time', y='Survival', color='Group')) + ggplot(plot_data, aes(x="Time", y="Survival", color="Group")) + geom_step() - + labs(x='Time', y='Survival Probability', color='Group') - + ggtitle('Kaplan-Meier Survival Curves by Group') - + annotate("text", x=0.1, y=0.1, label=p_text, size=10, ha='left') + + labs(x="Time", y="Survival Probability", color="Group") + + ggtitle("Kaplan-Meier Survival Curves by Group") + + annotate("text", x=0.1, y=0.1, label=p_text, size=10, ha="left") + theme_minimal() - + theme(legend_title=element_text(size=10, weight='bold')) + + theme(legend_title=element_text(size=10, weight="bold")) + scale_color_manual(values=color_mapping) ) return p + def plot_scatter(true_values, predicted_values): """ - Plots a scatterplot of true vs predicted values, with a regression line and annotated with the Pearson correlation coefficient. + Plots a scatterplot of true vs predicted values, with a regression line and + annotated with the Pearson correlation coefficient. 
Args: true_values (list or np.array): True values @@ -267,28 +266,43 @@ def plot_scatter(true_values, predicted_values): corr_text = f"Pearson r: {corr:.2f}" # Create DataFrame - df = pd.DataFrame({"True Values": true_values, "Predicted Values": predicted_values}) + df = pd.DataFrame( + {"True Values": true_values, "Predicted Values": predicted_values} + ) # Generate scatter plot with regression line plot = ( - ggplot(df, aes(x="True Values", y="Predicted Values")) + - geom_point(alpha=0.5) + - geom_smooth(method='lm', color='red') + - annotate("text", x=min(true_values), y=max(predicted_values), label=corr_text, ha='left', va='top', size=10) + - labs( + ggplot(df, aes(x="True Values", y="Predicted Values")) + + geom_point(alpha=0.5) + + geom_smooth(method="lm", color="red") + + annotate( + "text", + x=min(true_values), + y=max(predicted_values), + label=corr_text, + ha="left", + va="top", + size=10, + ) + + labs( title="True vs Predicted Values", x="True Values", - y="Predicted Values" - ) + - theme_minimal() + y="Predicted Values", + ) + + theme_minimal() ) return plot -from scipy.stats import mannwhitneyu, kruskal - -def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Values', figsize=(10, 6), jittersize = 4): +def plot_boxplot( + categorical_x, + numerical_y, + title_x="Categories", + title_y="Values", + figsize=(10, 6), + jittersize=4, +): df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y}) # Compute p-value @@ -296,7 +310,7 @@ def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Valu if len(groups) == 2: group1 = df[df[title_x] == groups[0]][title_y] group2 = df[df[title_x] == groups[1]][title_y] - stat, p = mannwhitneyu(group1, group2, alternative='two-sided') + stat, p = mannwhitneyu(group1, group2, alternative="two-sided") test_name = "Mann-Whitney U" else: group_data = [df[df[title_x] == group][title_y] for group in groups] @@ -305,8 +319,25 @@ def plot_boxplot(categorical_x, numerical_y, 
title_x='Categories', title_y='Valu # Create a boxplot with jittered points plt.figure(figsize=figsize) - sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill= False) - sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4) + sns.boxplot( + x=title_x, + y=title_y, + hue=title_x, + data=df, + palette="Set2", + legend=False, + fill=False, + ) + sns.stripplot( + x=title_x, + y=title_y, + data=df, + color="black", + size=jittersize, + jitter=True, + dodge=True, + alpha=0.4, + ) # Labels and p-value annotation plt.xlabel(title_x) @@ -314,21 +345,23 @@ def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Valu plt.text( x=-0.4, y=plt.ylim()[1], - s=f'{test_name} p = {p:.3e}', - verticalalignment='top', - horizontalalignment='left', + s=f"{test_name} p = {p:.3e}", + verticalalignment="top", + horizontalalignment="left", fontsize=12, - bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray') + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray"), ) plt.tight_layout() return plt.gcf() + # given a vector of numerical values which may contain # NAN values, return a binary grouping based on median values def split_by_median(v): return ((v - torch.nanmedian(v)) > 0).float() + def evaluate_survival(outputs, durations, events): """ Computes the concordance index (c-index) for survival predictions. 
@@ -358,14 +391,16 @@ def evaluate_survival(outputs, durations, events): # Compute concordance index (lifelines expects higher risk → lower survival) c_index = concordance_index(durations, -outputs, events) - return {'cindex': c_index} + return {"cindex": c_index} + def generate_bootstrap_indices(n, n_bootstraps=1000, seed=42): rng = np.random.default_rng(seed) return [rng.choice(n, size=n, replace=True) for _ in range(n_bootstraps)] + # bootstrapping function for regression/classification tasks -def bootstrap_metric(y_true, y_pred, indices_list, metric_fn, ci = 95, **kwargs): +def bootstrap_metric(y_true, y_pred, indices_list, metric_fn, ci=95, **kwargs): scores = [] y_true = np.array(y_true) y_pred = np.array(y_pred) @@ -377,12 +412,15 @@ def bootstrap_metric(y_true, y_pred, indices_list, metric_fn, ci = 95, **kwargs) upper = np.percentile(scores, 100 - (100 - ci) / 2) return scores, (np.mean(scores), lower, upper) + def evaluate_classifier(y_true, y_probs, print_report=False): """ - Evaluate the performance of a classifier using multiple metrics and optionally print a detailed classification report. + Evaluate the performance of a classifier using multiple metrics and optionally + print a detailed classification report. - This function computes balanced accuracy, F1 score (weighted), Cohen's Kappa score, average AUROC score, and - weighted-average AUC-PR score for the given true labels and predicted probabilities. + This function computes balanced accuracy, F1 score (weighted), Cohen's Kappa + score, average AUROC score, and weighted-average AUC-PR score for the given + true labels and predicted probabilities. If `print_report` is set to True, it prints a detailed classification report. 
Args: @@ -405,7 +443,7 @@ def evaluate_classifier(y_true, y_probs, print_report=False): balanced_acc = balanced_accuracy_score(y_true, y_pred) # F1 score (weighted) - f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0) + f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0) # Cohen's Kappa kappa = cohen_kappa_score(y_true, y_pred) @@ -415,10 +453,16 @@ def evaluate_classifier(y_true, y_probs, print_report=False): if y_probs.shape[1] == 2: # Binary classification y_probs_binary = y_probs[:, 1] # Use positive class probabilities average_auroc = roc_auc_score(y_true, y_probs_binary) - average_aupr = average_precision_score(y_true, y_probs_binary) # AUC-PR for binary case + average_aupr = average_precision_score( + y_true, y_probs_binary + ) # AUC-PR for binary case else: # Multiclass classification - average_auroc = roc_auc_score(y_true, y_probs, multi_class='ovr', average='weighted') - average_aupr = average_precision_score(y_true, y_probs, average='weighted') # Weighted AUC-PR for multiclass + average_auroc = roc_auc_score( + y_true, y_probs, multi_class="ovr", average="weighted" + ) + average_aupr = average_precision_score( + y_true, y_probs, average="weighted" + ) # Weighted AUC-PR for multiclass except ValueError: average_auroc = None # Handle cases where AUROC cannot be computed average_aupr = None # Handle cases where AUC-PR cannot be computed @@ -434,12 +478,9 @@ def evaluate_classifier(y_true, y_probs, print_report=False): "f1_score": f1, "kappa": kappa, "average_auroc": average_auroc, - "average_aupr": average_aupr # Added AUC-PR + "average_aupr": average_aupr, # Added AUC-PR } -from sklearn.metrics import roc_curve, roc_auc_score -from sklearn.preprocessing import label_binarize -from sklearn.metrics import precision_recall_curve, average_precision_score def plot_roc_curves(y_true, y_probs): """ @@ -457,7 +498,15 @@ def plot_roc_curves(y_true, y_probs): # Binary classification fpr, tpr, _ = roc_curve(y_true, y_probs[:, 1]) 
auc_score = roc_auc_score(y_true, y_probs[:, 1]) - plot_data.append(pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'label': [f'Class 1 (AUC = {auc_score:.2f})'] * len(fpr)})) + plot_data.append( + pd.DataFrame( + { + "fpr": fpr, + "tpr": tpr, + "label": [f"Class 1 (AUC = {auc_score:.2f})"] * len(fpr), + } + ) + ) else: # Multiclass classification classes = np.arange(n_classes) @@ -466,11 +515,13 @@ def plot_roc_curves(y_true, y_probs): for i in classes: fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i]) auc_score = roc_auc_score(y_true_bin[:, i], y_probs[:, i]) - df = pd.DataFrame({ - 'fpr': fpr, - 'tpr': tpr, - 'label': [f'Class {i} (AUC = {auc_score:.2f})'] * len(fpr) - }) + df = pd.DataFrame( + { + "fpr": fpr, + "tpr": tpr, + "label": [f"Class {i} (AUC = {auc_score:.2f})"] * len(fpr), + } + ) plot_data.append(df) # Combine all data @@ -478,19 +529,16 @@ def plot_roc_curves(y_true, y_probs): # Plot using plotnine roc_plot = ( - ggplot(all_data, aes(x='fpr', y='tpr', color='label')) + - geom_line(size=1.2) + - geom_abline(intercept=0, slope=1, linetype='dashed', color='gray') + - labs( - title='ROC Curve', - x='False Positive Rate', - y='True Positive Rate' - ) + - theme_minimal() + ggplot(all_data, aes(x="fpr", y="tpr", color="label")) + + geom_line(size=1.2) + + geom_abline(intercept=0, slope=1, linetype="dashed", color="gray") + + labs(title="ROC Curve", x="False Positive Rate", y="True Positive Rate") + + theme_minimal() ) return roc_plot + def plot_pr_curves(y_true, y_probs): """ Plot Precision-Recall (PR) curves using plotnine for binary or multiclass classification. 
@@ -507,24 +555,32 @@ def plot_pr_curves(y_true, y_probs): # Binary classification precision, recall, _ = precision_recall_curve(y_true, y_probs[:, 1]) aupr = average_precision_score(y_true, y_probs[:, 1]) - plot_data.append(pd.DataFrame({ - 'recall': recall, - 'precision': precision, - 'label': [f'Class 1 (AUPR = {aupr:.2f})'] * len(recall) - })) + plot_data.append( + pd.DataFrame( + { + "recall": recall, + "precision": precision, + "label": [f"Class 1 (AUPR = {aupr:.2f})"] * len(recall), + } + ) + ) else: # Multiclass classification (one-vs-rest) classes = np.arange(n_classes) y_true_bin = label_binarize(y_true, classes=classes) for i in classes: - precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_probs[:, i]) + precision, recall, _ = precision_recall_curve( + y_true_bin[:, i], y_probs[:, i] + ) aupr = average_precision_score(y_true_bin[:, i], y_probs[:, i]) - df = pd.DataFrame({ - 'recall': recall, - 'precision': precision, - 'label': [f'Class {i} (AUPR = {aupr:.2f})'] * len(recall) - }) + df = pd.DataFrame( + { + "recall": recall, + "precision": precision, + "label": [f"Class {i} (AUPR = {aupr:.2f})"] * len(recall), + } + ) plot_data.append(df) # Combine all data @@ -532,14 +588,10 @@ def plot_pr_curves(y_true, y_probs): # Plot using plotnine pr_plot = ( - ggplot(all_data, aes(x='recall', y='precision', color='label')) + - geom_line(size=1.2) + - labs( - title='Precision-Recall Curve', - x='Recall', - y='Precision' - ) + - theme_minimal() + ggplot(all_data, aes(x="recall", y="precision", color="label")) + + geom_line(size=1.2) + + labs(title="Precision-Recall Curve", x="Recall", y="Precision") + + theme_minimal() ) return pr_plot @@ -547,72 +599,95 @@ def plot_pr_curves(y_true, y_probs): def evaluate_regressor(y_true, y_pred): """ - Evaluate the performance of a regression model using mean squared error, R-squared, and Pearson correlation coefficient. 
+ Evaluate the performance of a regression model using mean squared error, + R-squared, and Pearson correlation coefficient. - This function computes the mean squared error (MSE) between true and predicted values as a measure of prediction accuracy. - It also performs a linear regression analysis between the true and predicted values to obtain the R-squared value, which - explains the variance ratio, and the Pearson correlation coefficient, providing insight into the linear relationship strength. + This function computes the mean squared error (MSE) between true and predicted + values as a measure of prediction accuracy. It also performs a linear regression + analysis between the true and predicted values to obtain the R-squared value, + which explains the variance ratio, and the Pearson correlation coefficient, + providing insight into the linear relationship strength. Args: - y_true (array-like): True values of the dependent variable, must be a 1D list or array. - y_pred (array-like): Predicted values as returned by a regressor, must match the dimensions of y_true. + y_true (array-like): True values of the dependent variable, must be a 1D + list or array. + y_pred (array-like): Predicted values as returned by a regressor, must + match the dimensions of y_true. Returns: dict: A dictionary containing: - 'mse': The mean squared error between the true and predicted values. - - 'r2': The R-squared value indicating the proportion of variance in the dependent variable predictable from the independent variable. - - 'pearson_corr': The Pearson correlation coefficient indicating the linear relationship strength between the true and predicted values. + - 'r2': The R-squared value indicating the proportion of variance in + the dependent variable predictable from the independent variable. + - 'pearson_corr': The Pearson correlation coefficient indicating the + linear relationship strength between the true and predicted values. 
""" mse = mean_squared_error(y_true, y_pred) - slope, intercept, r_value, p_value, std_err = linregress(y_true,y_pred) + slope, intercept, r_value, p_value, std_err = linregress(y_true, y_pred) r2 = r_value**2 return {"mse": mse, "r2": r2, "pearson_corr": r_value} -def evaluate_wrapper(method, y_pred_dict, dataset, surv_event_var = None, surv_time_var = None): + +def evaluate_wrapper( + method, y_pred_dict, dataset, surv_event_var=None, surv_time_var=None +): """ - Evaluates predictions for different variables within a dataset using appropriate metrics based on the variable type. - Supports evaluation for numerical, categorical, and survival data. + Evaluates predictions for different variables within a dataset using + appropriate metrics based on the variable type. Supports evaluation for + numerical, categorical, and survival data. - This function loops through each variable in the predictions dictionary, determines the type of the variable, - and evaluates the predictions using the appropriate method: regression, classification, or survival analysis. - It compiles the metrics into a list of dictionaries, which is then converted into a pandas DataFrame. + This function loops through each variable in the predictions dictionary, + determines the type of the variable, and evaluates the predictions using the + appropriate method: regression, classification, or survival analysis. It + compiles the metrics into a list of dictionaries, which is then converted into + a pandas DataFrame. Args: method (str): Identifier for the prediction method or model used. - y_pred_dict (dict): A dictionary where keys are variable names and values are arrays of predicted values. - dataset (Dataset): A dataset object containing actual values and metadata such as variable types. - surv_event_var (str, optional): The name of the survival event variable. Required if survival analysis is performed. - surv_time_var (str, optional): The name of the survival time variable. 
Required if survival analysis is performed. + y_pred_dict (dict): A dictionary where keys are variable names and values + are arrays of predicted values. + dataset (Dataset): A dataset object containing actual values and metadata + such as variable types. + surv_event_var (str, optional): The name of the survival event variable. + Required if survival analysis is performed. + surv_time_var (str, optional): The name of the survival time variable. + Required if survival analysis is performed. Returns: - pd.DataFrame: A DataFrame where each row contains the method, variable name, variable type, metric name, and metric value. + pd.DataFrame: A DataFrame where each row contains the method, variable + name, variable type, metric name, and metric value. """ metrics_list = [] for var in y_pred_dict.keys(): - if dataset.variable_types[var] == 'numerical': + if dataset.variable_types[var] == "numerical": if var == surv_event_var: events = dataset.ann[surv_event_var] durations = dataset.ann[surv_time_var] metrics = evaluate_survival(y_pred_dict[var], durations, events) else: ind = ~torch.isnan(dataset.ann[var]) - metrics = evaluate_regressor(dataset.ann[var][ind], y_pred_dict[var][ind].flatten()) + metrics = evaluate_regressor( + dataset.ann[var][ind], y_pred_dict[var][ind].flatten() + ) else: ind = ~torch.isnan(dataset.ann[var]) metrics = evaluate_classifier(dataset.ann[var][ind], y_pred_dict[var][ind]) for metric, value in metrics.items(): - metrics_list.append({ - 'method': method, - 'var': var, - 'variable_type': dataset.variable_types[var], - 'metric': metric, - 'value': value - }) + metrics_list.append( + { + "method": method, + "var": var, + "variable_type": dataset.variable_types[var], + "metric": metric, + "value": value, + } + ) # Convert the list of metrics to a DataFrame return pd.DataFrame(metrics_list) + def get_predicted_labels(y_pred_dict, dataset, split, method_name): """ Generate a DataFrame with class probabilities and associated metadata. 
@@ -637,61 +712,90 @@ def get_predicted_labels(y_pred_dict, dataset, split, method_name): dfs = [] for var in y_pred_dict.keys(): - if dataset.variable_types[var] == 'categorical': + if dataset.variable_types[var] == "categorical": # Predicted probabilities probabilities = y_pred_dict[var] # Convert class indices to labels if mappings exist if var in dataset.label_mappings.keys(): - class_labels = [dataset.label_mappings[var][idx] for idx in range(probabilities.shape[1])] + class_labels = [ + dataset.label_mappings[var][idx] + for idx in range(probabilities.shape[1]) + ] else: - class_labels = [f'class_{i}' for i in range(probabilities.shape[1])] + class_labels = [f"class_{i}" for i in range(probabilities.shape[1])] # Get true labels (y_true) - y_true = [dataset.label_mappings[var][int(x.item())] if var in dataset.label_mappings.keys() and not np.isnan(x.item()) else np.nan for x in dataset.ann[var]] + y_true = [ + ( + dataset.label_mappings[var][int(x.item())] + if var in dataset.label_mappings.keys() and not np.isnan(x.item()) + else np.nan + ) + for x in dataset.ann[var] + ] # Predicted labels (argmax of probabilities) y_pred_indices = np.argmax(probabilities, axis=1) - y_pred = [dataset.label_mappings[var][idx] if var in dataset.label_mappings.keys() else idx for idx in y_pred_indices] + y_pred = [ + ( + dataset.label_mappings[var][idx] + if var in dataset.label_mappings.keys() + else idx + ) + for idx in y_pred_indices + ] # Create a DataFrame for each sample and its probabilities for i, sample_id in enumerate(dataset.samples): for j, class_label in enumerate(class_labels): - dfs.append({ - 'sample_id': sample_id, - 'variable': var, - 'class_label': class_label, - 'probability': probabilities[i, j], - 'known_label': y_true[i], - 'predicted_label': y_pred[i], - 'split': split, - 'method': method_name - }) + dfs.append( + { + "sample_id": sample_id, + "variable": var, + "class_label": class_label, + "probability": probabilities[i, j], + "known_label": y_true[i], 
+ "predicted_label": y_pred[i], + "split": split, + "method": method_name, + } + ) else: # For numerical variables, set class_label and probability to NA y_true = [x.item() for x in dataset.ann[var]] y_pred = [x.item() for x in y_pred_dict[var]] for i, sample_id in enumerate(dataset.samples): - dfs.append({ - 'sample_id': sample_id, - 'variable': var, - 'class_label': np.nan, - 'probability': np.nan, - 'known_label': y_true[i], - 'predicted_label': y_pred[i], - 'split': split, - 'method': method_name - }) + dfs.append( + { + "sample_id": sample_id, + "variable": var, + "class_label": np.nan, + "probability": np.nan, + "known_label": y_true[i], + "predicted_label": y_pred[i], + "split": split, + "method": method_name, + } + ) # Combine all rows into a DataFrame return pd.DataFrame(dfs) - - -def evaluate_baseline_performance(train_dataset, test_dataset, variable_name, methods, n_folds=5, n_jobs=4, use_pca=False, n_components=100): +def evaluate_baseline_performance( + train_dataset, + test_dataset, + variable_name, + methods, + n_folds=5, + n_jobs=4, + use_pca=False, + n_components=100, +): """ - Evaluates the performance of machine learning models on a given variable with optional PCA for dimensionality reduction. + Evaluates the performance of machine learning models on a given variable with + optional PCA for dimensionality reduction. Args: train_dataset (Dataset): A MultiOmicDataset object containing training data and metadata. @@ -706,6 +810,7 @@ def evaluate_baseline_performance(train_dataset, test_dataset, variable_name, me Returns: pd.DataFrame: A DataFrame containing metrics for each method. 
""" + def prepare_data(data_object, pca_model=None, fit_pca=False): # Concatenate Data Matrices X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) @@ -733,41 +838,63 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): # Cross-Validation and Training kf = KFold(n_splits=n_folds, shuffle=True, random_state=42) - X_train, y_train, train_indices = prepare_data(train_dataset, pca_model=pca_model, fit_pca=True) + X_train, y_train, train_indices = prepare_data( + train_dataset, pca_model=pca_model, fit_pca=True + ) print("Train:", X_train.shape) - X_test, y_test, test_indices = prepare_data(test_dataset, pca_model=pca_model, fit_pca=False) + X_test, y_test, test_indices = prepare_data( + test_dataset, pca_model=pca_model, fit_pca=False + ) print("Test:", X_test.shape) metrics_list = [] predictions = [] # Collect all predictions for method in methods: - if variable_type == 'categorical': - if method == 'RandomForest': + if variable_type == "categorical": + if method == "RandomForest": model = RandomForestClassifier(random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [10, 20, None]} - elif method == 'SVM': + params = { + "n_estimators": [100, 200, 300], + "max_depth": [10, 20, None], + } + elif method == "SVM": model = SVC(probability=True, random_state=42) - params = {'C': [0.1, 1, 10], 'kernel': ['rbf', 'poly']} - elif method == 'XGBoost': + params = {"C": [0.1, 1, 10], "kernel": ["rbf", "poly"]} + elif method == "XGBoost": if XGBClassifier is None: - print("[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping.") + print( + "[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping." 
+ ) continue - model = XGBClassifier(eval_metric='logloss', random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [3, 6, 9], 'learning_rate': [0.01, 0.1, 0.2]} - elif variable_type == 'numerical': - if method == 'RandomForest': + model = XGBClassifier(eval_metric="logloss", random_state=42) + params = { + "n_estimators": [100, 200, 300], + "max_depth": [3, 6, 9], + "learning_rate": [0.01, 0.1, 0.2], + } + elif variable_type == "numerical": + if method == "RandomForest": model = RandomForestRegressor(random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [10, 20, None]} - elif method == 'SVM': + params = { + "n_estimators": [100, 200, 300], + "max_depth": [10, 20, None], + } + elif method == "SVM": model = SVR() - params = {'C': [0.1, 1, 10], 'kernel': ['rbf', 'poly']} - elif method == 'XGBoost': + params = {"C": [0.1, 1, 10], "kernel": ["rbf", "poly"]} + elif method == "XGBoost": if XGBRegressor is None: - print("[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping.") + print( + "[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping." 
+ ) continue model = XGBRegressor(random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [3, 6, 9], 'learning_rate': [0.01, 0.1, 0.2]} + params = { + "n_estimators": [100, 200, 300], + "max_depth": [3, 6, 9], + "learning_rate": [0.01, 0.1, 0.2], + } print("Training method:", method) grid_search = GridSearchCV(model, params, cv=kf, n_jobs=n_jobs) @@ -775,56 +902,69 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): best_model = grid_search.best_estimator_ # Predict on test data - if variable_type == 'categorical': + if variable_type == "categorical": y_probs = best_model.predict_proba(X_test) metrics = evaluate_classifier(y_test, y_probs, print_report=True) y_pred_dict = {variable_name: y_probs} - elif variable_type == 'numerical': + elif variable_type == "numerical": y_pred = best_model.predict(X_test) metrics = evaluate_regressor(y_test, y_pred) y_pred_dict = {variable_name: y_pred} # need to get test indices to only consider samples with labels - df_preds = get_predicted_labels(y_pred_dict, test_dataset.subset(test_indices), 'test', method) + df_preds = get_predicted_labels( + y_pred_dict, test_dataset.subset(test_indices), "test", method + ) predictions.append(df_preds) for metric, value in metrics.items(): - metrics_list.append({ - 'method': method + ('Classifier' if variable_type == 'categorical' else 'Regressor'), - 'var': variable_name, - 'variable_type': variable_type, - 'metric': metric, - 'value': value - }) + metrics_list.append( + { + "method": method + + ("Classifier" if variable_type == "categorical" else "Regressor"), + "var": variable_name, + "variable_type": variable_type, + "metric": metric, + "value": value, + } + ) predictions = pd.concat(predictions, ignore_index=True) return pd.DataFrame(metrics_list), predictions - -def evaluate_baseline_survival_performance(train_dataset, test_dataset, duration_col, event_col, n_folds=5, n_jobs=4): +def evaluate_baseline_survival_performance( + train_dataset, test_dataset, 
duration_col, event_col, n_folds=5, n_jobs=4 +): """ - Evaluates the baseline performance of a Random Survival Forest model on survival data using the Concordance Index. + Evaluates the baseline performance of a Random Survival Forest model on + survival data using the Concordance Index. - The function preprocesses both training and testing datasets to prepare appropriate survival data (comprising durations - and event occurrences), performs cross-validation to assess model robustness, and then calculates the Concordance Index on - the test data. It uses a Random Survival Forest (RSF) as the predictive model. + The function preprocesses both training and testing datasets to prepare + appropriate survival data (comprising durations and event occurrences), performs + cross-validation to assess model robustness, and then calculates the Concordance + Index on the test data. It uses a Random Survival Forest (RSF) as the + predictive model. Args: train_dataset (Dataset): The training dataset (a MultiOmicDataset object) containing features and survival data. test_dataset (Dataset): The testing dataset (a MultiOmicDataset object) containing features and survival data. duration_col (str): Column name in the dataset for survival time. event_col (str): Column name in the dataset for the event occurrence (1 if event occurred, 0 otherwise). - n_folds (int, optional): Number of folds for K-fold cross-validation. Defaults to 5. - n_jobs (int, optional): Number of parallel jobs to run for Random Survival Forest training. Defaults to 4. + n_folds (int, optional): Number of folds for K-fold cross-validation. + Defaults to 5. + n_jobs (int, optional): Number of parallel jobs to run for Random Survival + Forest training. Defaults to 4. Returns: - pd.DataFrame: A DataFrame containing the performance metrics of the RSF model, specifically the Concordance Index, - listed along with the method name and variable details. 
+ pd.DataFrame: A DataFrame containing the performance metrics of the RSF + model, specifically the Concordance Index, listed along with the method + name and variable details. """ - print(f"[INFO] Evaluating baseline survival prediction performance") + print("[INFO] Evaluating baseline survival prediction performance") + def prepare_data(data_object, duration_col, event_col): # Concatenate Data Matrices X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) @@ -832,8 +972,10 @@ def prepare_data(data_object, duration_col, event_col): # Prepare Survival Data (Durations and Events) durations = np.array(data_object.ann[duration_col]) events = np.array(data_object.ann[event_col]) - y = np.array([(event, duration) for event, duration in zip(events, durations)], - dtype=[('Event', '?'), ('Time', '= alpha or resulting child would have < min_samples_per_group. @@ -1080,7 +1257,7 @@ def recursive_binary_split_minN(df, score='pred_risk', time='OS.time', event='OS continue left = node[node[score] <= cutoff] - right = node[node[score] > cutoff] + right = node[node[score] > cutoff] # if either resulting child is smaller than required, reject split if len(left) < min_samples_per_group or len(right) < min_samples_per_group: @@ -1092,7 +1269,7 @@ def recursive_binary_split_minN(df, score='pred_risk', time='OS.time', event='OS queue.append(left) queue.append(right) - df['auto_group'] = df.index.map(groups) + df["auto_group"] = df.index.map(groups) return df @@ -1106,65 +1283,73 @@ def plot_hazard_ratios(cox_model): cox_model = cox_model[0] # Extract summary - coef_summary = cox_model.summary[['coef', 'coef lower 95%', 'coef upper 95%', 'p']].copy() - coef_summary.columns = ['coef', 'coef_lower_95', 'coef_upper_95', 'p'] - coef_summary['variable'] = coef_summary.index + coef_summary = cox_model.summary[ + ["coef", "coef lower 95%", "coef upper 95%", "p"] + ].copy() + coef_summary.columns = ["coef", "coef_lower_95", "coef_upper_95", "p"] + 
coef_summary["variable"] = coef_summary.index # Sort by p-value - coef_summary_sorted = coef_summary.sort_values('p').reset_index(drop=True) + coef_summary_sorted = coef_summary.sort_values("p").reset_index(drop=True) # Add significance stars def significance(p): if p < 0.0001: - return '***' + return "***" elif p < 0.001: - return '**' + return "**" elif p < 0.05: - return '*' + return "*" elif p < 0.1: - return '.' + return "." else: - return '' + return "" - coef_summary_sorted['stars'] = coef_summary_sorted['p'].apply(significance) + coef_summary_sorted["stars"] = coef_summary_sorted["p"].apply(significance) # Reverse the order for top-to-bottom importance - coef_summary_sorted['variable'] = pd.Categorical( - coef_summary_sorted['variable'], - categories=coef_summary_sorted['variable'][::-1], - ordered=True + coef_summary_sorted["variable"] = pd.Categorical( + coef_summary_sorted["variable"], + categories=coef_summary_sorted["variable"][::-1], + ordered=True, ) c_index = cox_model.concordance_index_ # Plot p = ( - ggplot(coef_summary_sorted, aes(x='coef', y='variable')) + ggplot(coef_summary_sorted, aes(x="coef", y="variable")) + geom_errorbarh( - aes(xmin='coef_lower_95', xmax='coef_upper_95'), - height=0.2, color='skyblue' + aes(xmin="coef_lower_95", xmax="coef_upper_95"), + height=0.2, + color="skyblue", + ) + + geom_point(color="skyblue", size=3) + + geom_text(aes(label="stars"), nudge_y=0.1, size=10) + + annotate("vline", xintercept=0, linetype="dashed", color="gray") + + labs( + x="Log Hazard Ratio", + y="", + title=f"Log Hazard Ratios Sorted by P-Value with 95% CI\n Model C-index: {c_index:.2f}", ) - + geom_point(color='skyblue', size=3) - + geom_text(aes(label='stars'), nudge_y=0.1, size=10) - + annotate('vline', xintercept=0, linetype='dashed', color='gray') - + labs(x='Log Hazard Ratio', y='', title=f'Log Hazard Ratios Sorted by P-Value with 95% CI\n Model C-index: {c_index:.2f}') + theme_bw() + theme( axis_text_y=element_text(size=10), 
axis_text_x=element_text(size=10), - plot_title=element_text(weight='bold'), + plot_title=element_text(weight="bold"), ) ) return p + def build_cox_model( df: pd.DataFrame, duration_col: str, event_col: str, n_splits: int = 5, random_state: int = 42, - eval_time: float | None = None, # single horizon, same units as duration_col + eval_time: float | None = None, # single horizon, same units as duration_col low_variance_threshold: float = 0.01, - cox_penalizer = 0.05, + cox_penalizer=0.05, return_metrics: bool = True, ): """ @@ -1186,14 +1371,18 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold for feature in df.drop(columns=[duration_col, event_col]).columns: v1 = df.loc[events, feature].var() v0 = df.loc[~events, feature].var() - if (v1 is not None and v1 < threshold) or (v0 is not None and v0 < threshold): + if (v1 is not None and v1 < threshold) or ( + v0 is not None and v0 < threshold + ): low_var.append(feature) df_f = df.drop(columns=low_var, errors="ignore") if low_var: print("Removed low variance features:", low_var) return df_f - df = remove_low_variance_survival_features(df, duration_col, event_col, low_variance_threshold) + df = remove_low_variance_survival_features( + df, duration_col, event_col, low_variance_threshold + ) metrics = { "cv_cindex_mean": None, "cv_auc_mean": None, @@ -1205,7 +1394,7 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold for train_idx, test_idx in kf.split(df): train_df = df.iloc[train_idx] - test_df = df.iloc[test_idx] + test_df = df.iloc[test_idx] model = CoxPHFitter(penalizer=cox_penalizer) model.fit(train_df, duration_col=duration_col, event_col=event_col) @@ -1215,7 +1404,7 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold ci = concordance_index( test_df[duration_col].values, -risk_scores, # higher hazard => higher risk - test_df[event_col].astype(int).values + test_df[event_col].astype(int).values, ) c_indices.append(ci) 
@@ -1223,11 +1412,11 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold if eval_time is not None: y_train = Surv.from_arrays( event=train_df[event_col].astype(bool).values, - time=train_df[duration_col].values + time=train_df[duration_col].values, ) y_test = Surv.from_arrays( event=test_df[event_col].astype(bool).values, - time=test_df[duration_col].values + time=test_df[duration_col].values, ) # Only evaluate if the horizon lies strictly inside this test fold's follow-up @@ -1236,7 +1425,10 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold eps = 1e-8 if (test_min + eps) < float(eval_time) < (test_max - eps): _, auc_val = cumulative_dynamic_auc( - y_train, y_test, risk_scores, np.asarray([float(eval_time)]) + y_train, + y_test, + risk_scores, + np.asarray([float(eval_time)]), ) auc_val = float(np.atleast_1d(auc_val)[0]) auc_per_fold.append(auc_val) @@ -1264,7 +1456,7 @@ def k_means_clustering(data, k): - kmeans: The fitted KMeans instance, which can be used to access cluster centers and other attributes. """ # Initialize the KMeans model - kmeans = KMeans(n_clusters=k, n_init='auto', random_state=42) + kmeans = KMeans(n_clusters=k, n_init="auto", random_state=42) # Fit the model to the data kmeans.fit(data) @@ -1274,6 +1466,7 @@ def k_means_clustering(data, k): return cluster_labels, kmeans + def louvain_clustering(X, threshold=None, k=None): """ Create a graph from pairwise distances within X. 
You can define a threshold to connect edges @@ -1293,11 +1486,11 @@ def louvain_clustering(X, threshold=None, k=None): for j in range(i + 1, distances.shape[0]): # If a threshold is defined, use it to create edges if threshold is not None and distances[i, j] < threshold: - G.add_edge(i, j, weight=1/distances[i, j]) + G.add_edge(i, j, weight=1 / distances[i, j]) # If k is defined, add an edge if j is one of i's k-nearest neighbors elif k is not None: - if np.argsort(distances[i])[:k + 1].__contains__(j): - G.add_edge(i, j, weight=1/distances[i, j]) + if np.argsort(distances[i])[: k + 1].__contains__(j): + G.add_edge(i, j, weight=1 / distances[i, j]) partition = community_louvain.best_partition(G) cluster_labels = np.full(len(X), np.nan, dtype=float) @@ -1331,33 +1524,39 @@ def get_optimal_clusters(data, min_k=2, max_k=10): cluster_labels_dict = {} # To store cluster labels for each k for k in range(min_k, max_k + 1): - kmeans = KMeans(n_clusters=k, n_init = 'auto', random_state=42) + kmeans = KMeans(n_clusters=k, n_init="auto", random_state=42) cluster_labels = kmeans.fit_predict(data) silhouette_avg = silhouette_score(data, cluster_labels) silhouette_scores.append((k, silhouette_avg)) cluster_labels_dict[k] = cluster_labels # Store cluster labels - #print(f"Number of clusters: {k}, Silhouette Score: {silhouette_avg:.4f}") + # print(f"Number of clusters: {k}, Silhouette Score: {silhouette_avg:.4f}") # Convert silhouette scores to DataFrame for easier handling and visualization - silhouette_scores_df = pd.DataFrame(silhouette_scores, columns=['k', 'silhouette_score']) + silhouette_scores_df = pd.DataFrame( + silhouette_scores, columns=["k", "silhouette_score"] + ) # Find the optimal k (number of clusters) with the highest silhouette score - optimal_k = silhouette_scores_df.loc[silhouette_scores_df['silhouette_score'].idxmax()]['k'] + optimal_k = silhouette_scores_df.loc[ + silhouette_scores_df["silhouette_score"].idxmax() + ]["k"] # Retrieve the cluster labels for 
the optimal k optimal_cluster_labels = cluster_labels_dict[optimal_k] return optimal_cluster_labels, optimal_k, silhouette_scores_df + # compute adjusted rand index; adjusted mutual information for two sets of paired labels def compute_ami_ari(labels1, labels2): - def convert_nan (labels): - return ['unavailable' if pd.isna(x) else x for x in labels] + def convert_nan(labels): + return ["unavailable" if pd.isna(x) else x for x in labels] + labels1 = convert_nan(labels1) labels2 = convert_nan(labels2) ami = adjusted_mutual_info_score(labels1, labels2) ari = adjusted_rand_score(labels1, labels2) - return {'ami': ami, 'ari': ari} + return {"ami": ami, "ari": ari} def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)): @@ -1369,18 +1568,23 @@ def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)): - labels2: The second set of labels. """ # Compute the cross-tabulation - ct = pd.crosstab(pd.Series(labels1, name='Labels Set 1'), pd.Series(labels2, name='Labels Set 2')) + ct = pd.crosstab( + pd.Series(labels1, name="Labels Set 1"), + pd.Series(labels2, name="Labels Set 2"), + ) # Normalize the cross-tabulation matrix column-wise ct_normalized = ct.div(ct.sum(axis=1), axis=0) # Plot the heatmap - plt.figure(figsize = figsize) - sns.heatmap(ct_normalized, annot=True,cmap='viridis', linewidths=.5)# col_cluster=False) - plt.title('Concordance between label groups') + plt.figure(figsize=figsize) + sns.heatmap( + ct_normalized, annot=True, cmap="viridis", linewidths=0.5 + ) # col_cluster=False) + plt.title("Concordance between label groups") return plt.gcf() -def scale_and_standardize_by_labels(data_matrix, labels): +def scale_and_standardize_by_labels(data_matrix, labels): """ Scale and standardize data_matrix by factor labels. Data is split by factors and each subset is scaled/standardized. 
@@ -1417,47 +1621,63 @@ def scale_and_standardize_by_labels(data_matrix, labels): return scaled_data_matrix + # df: annotation data frame ('clin.csv') # given a pandas data frame, go through each column and find out if the column is numeric or categorical def get_variable_types(df): # Select only the categorical columns - df_categorical = df.select_dtypes(include=['object', 'category']) - variable_types = {col: 'categorical' for col in df_categorical.columns} - variable_types.update({col: 'numerical' for col in df.select_dtypes(exclude=['object', 'category']).columns}) + df_categorical = df.select_dtypes(include=["object", "category"]) + variable_types = {col: "categorical" for col in df_categorical.columns} + variable_types.update( + { + col: "numerical" + for col in df.select_dtypes(exclude=["object", "category"]).columns + } + ) return variable_types + def create_covariate_matrix(covariates, variable_types, ann): """ Convert clinical variables used as covariates into a covariate matrix as a Pandas DataFrame. Missing values in numerical variables are imputed using the median. Args: - covariates (list of str): List of variable names that must exist in the "clin.csv". - variable_types (dict): Dictionary mapping variable names to their types ('categorical' or 'numerical'). + covariates (list of str): List of variable names that must exist in the + "clin.csv". + variable_types (dict): Dictionary mapping variable names to their types + ('categorical' or 'numerical'). ann (pd.DataFrame): Annotation DataFrame containing batch variable values. Returns: - pd.DataFrame: A covariate matrix DataFrame where categorical variables are one-hot-encoded as 0/1 and numerical variables are imputed, - with features as rows and samples as columns. + pd.DataFrame: A covariate matrix DataFrame where categorical variables are + one-hot-encoded as 0/1 and numerical variables are imputed, with + features as rows and samples as columns. 
""" covariate_features = [] feature_names = [] for var in covariates: - if variable_types.get(var) == 'categorical': + if variable_types.get(var) == "categorical": # One-hot-encode categorical variables with 0/1 encoding one_hot = pd.get_dummies(ann[var], prefix=var).astype(int) covariate_features.append(one_hot.T) # Transpose to make features rows feature_names.extend(one_hot.columns.tolist()) - elif variable_types.get(var) == 'numerical': + elif variable_types.get(var) == "numerical": # Handle numerical variables with missing values numerical_data = ann[[var]].copy() # Impute missing values using the median and assign back - numerical_data[var] = numerical_data[var].fillna(numerical_data[var].median()) - covariate_features.append(numerical_data.T) # Transpose to make features rows + numerical_data[var] = numerical_data[var].fillna( + numerical_data[var].median() + ) + covariate_features.append( + numerical_data.T + ) # Transpose to make features rows feature_names.append(var) else: - raise ValueError(f"Unknown variable type for {var}: {variable_types.get(var)}") + raise ValueError( + f"Unknown variable type for {var}: {variable_types.get(var)}" + ) # Concatenate all covariate features into a single DataFrame covariate_matrix = pd.concat(covariate_features, axis=0) @@ -1468,23 +1688,31 @@ def create_covariate_matrix(covariates, variable_types, ann): return covariate_matrix -def generate_synthetic_batches (n_samples_per_batch = 150, n_features = 50): + +def generate_synthetic_batches(n_samples_per_batch=150, n_features=50): # Generate batch 1 data (mean centered at 0, standard deviation 1) - batch1_data = np.random.normal(loc=0.0, scale=1.0, size=(n_samples_per_batch, n_features)) + batch1_data = np.random.normal( + loc=0.0, scale=1.0, size=(n_samples_per_batch, n_features) + ) # Generate batch 2 data (mean shifted by +2, standard deviation 1.5) - batch2_data = np.random.normal(loc=2.0, scale=1.5, size=(n_samples_per_batch, n_features)) + batch2_data = 
np.random.normal( + loc=2.0, scale=1.5, size=(n_samples_per_batch, n_features) + ) # Combine into a single dataset combined_data = np.vstack([batch1_data, batch2_data]) - batch_labels = np.array([0] * n_samples_per_batch + [1] * n_samples_per_batch) # Batch labels + batch_labels = np.array( + [0] * n_samples_per_batch + [1] * n_samples_per_batch + ) # Batch labels # Convert to Pandas DataFrame feature_columns = [f"feature_{i+1}" for i in range(n_features)] synthetic_data = pd.DataFrame(combined_data, columns=feature_columns) return synthetic_data, batch_labels -def optimal_transport_align(embeddings, batch_labels, standardize_by_labels = False): + +def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=False): """ Align embeddings from two batches using Optimal Transport, preserving the order of samples. @@ -1493,8 +1721,10 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels = Fa - batch_labels (np.ndarray or pd.Series): Batch labels corresponding to the rows of embeddings. Returns: - - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned embeddings for all samples, with original indices preserved. - - aligned_batch_labels (pd.Series): A Series containing the corresponding batch labels for the aligned embeddings. + - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned + embeddings for all samples, with original indices preserved. + - aligned_batch_labels (pd.Series): A Series containing the corresponding + batch labels for the aligned embeddings. 
""" # Ensure batch labels are a NumPy array batch_labels_np = np.array(batch_labels) @@ -1512,7 +1742,7 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels = Fa batch2_embeddings = embeddings.iloc[batch2_indices].to_numpy() # Compute the cost matrix (e.g., Euclidean distances) - cost_matrix = ot.dist(batch1_embeddings, batch2_embeddings, metric='euclidean') + cost_matrix = ot.dist(batch1_embeddings, batch2_embeddings, metric="euclidean") # Solve the optimal transport problem n_samples_1 = batch1_embeddings.shape[0] @@ -1534,19 +1764,32 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels = Fa scaler1 = StandardScaler() scaler2 = StandardScaler() - aligned_embeddings[batch1_indices] = scaler1.fit_transform(aligned_embeddings[batch1_indices]) - aligned_embeddings[batch2_indices] = scaler2.fit_transform(aligned_embeddings[batch2_indices]) + aligned_embeddings[batch1_indices] = scaler1.fit_transform( + aligned_embeddings[batch1_indices] + ) + aligned_embeddings[batch2_indices] = scaler2.fit_transform( + aligned_embeddings[batch2_indices] + ) # Convert back to pandas DataFrame and Series, preserving indices - aligned_embeddings_df = pd.DataFrame(aligned_embeddings, columns=embeddings.columns, index=embeddings.index) - aligned_batch_labels = pd.Series(batch_labels, index=embeddings.index, name="batch_labels") + aligned_embeddings_df = pd.DataFrame( + aligned_embeddings, columns=embeddings.columns, index=embeddings.index + ) + aligned_batch_labels = pd.Series( + batch_labels, index=embeddings.index, name="batch_labels" + ) return aligned_embeddings_df, aligned_batch_labels -from sklearn.neighbors import NearestNeighbors - -def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, standardize_by_labels=False, random_state=None): +def reciprocal_pca_mnn( + embeddings, + batch_labels, + n_components=10, + n_neighbors=5, + standardize_by_labels=False, + random_state=None, +): """ Align embeddings from 
two batches using Reciprocal PCA (rPCA) and Mutual Nearest Neighbors (MNN). @@ -1559,8 +1802,10 @@ def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, - random_state (int, optional): Random seed for reproducibility. Returns: - - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned embeddings for all samples, with original indices preserved. - - aligned_batch_labels (pd.Series): A Series containing the corresponding batch labels for the aligned embeddings. + - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned + embeddings for all samples, with original indices preserved. + - aligned_batch_labels (pd.Series): A Series containing the corresponding + batch labels for the aligned embeddings. """ # Ensure batch labels are a NumPy array batch_labels_np = np.array(batch_labels) @@ -1579,8 +1824,12 @@ def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, # Standardize embeddings separately for each batch if required if standardize_by_labels: - batch1_embeddings = (batch1_embeddings - batch1_embeddings.mean(axis=0)) / batch1_embeddings.std(axis=0) - batch2_embeddings = (batch2_embeddings - batch2_embeddings.mean(axis=0)) / batch2_embeddings.std(axis=0) + batch1_embeddings = ( + batch1_embeddings - batch1_embeddings.mean(axis=0) + ) / batch1_embeddings.std(axis=0) + batch2_embeddings = ( + batch2_embeddings - batch2_embeddings.mean(axis=0) + ) / batch2_embeddings.std(axis=0) # Perform PCA on both batches pca1 = PCA(n_components=n_components, random_state=random_state) @@ -1628,18 +1877,18 @@ def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, aligned_embeddings[batch2_indices] = aligned_batch2 # Convert back to pandas DataFrame and Series, preserving indices - aligned_embeddings_df = pd.DataFrame(aligned_embeddings, - columns=[f"rPCA_{i+1}" for i in range(n_components)], - index=embeddings.index) - aligned_batch_labels = pd.Series(batch_labels, 
index=embeddings.index, name="batch_labels") + aligned_embeddings_df = pd.DataFrame( + aligned_embeddings, + columns=[f"rPCA_{i+1}" for i in range(n_components)], + index=embeddings.index, + ) + aligned_batch_labels = pd.Series( + batch_labels, index=embeddings.index, name="batch_labels" + ) return aligned_embeddings_df, aligned_batch_labels -import tarfile -import requests -from io import StringIO - class CBioPortalData: def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): self.base_url = base_url @@ -1647,7 +1896,7 @@ def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): self.data_files = None self.data = None - def download_study_archive(self, force=False, timeout=60): + def download_study_archive(self, force=False, timeout=120): url = f"{self.base_url}/{self.study_id}.tar.gz" dest_file = f"{self.study_id}.tar.gz" @@ -1673,7 +1922,9 @@ def extract_archive(self, archive_path): with tarfile.open(archive_path, "r:gz") as tar: tar.extractall() - self.data_files = [f for f in os.listdir(base) if f.startswith("data_") and f.endswith(".txt")] + self.data_files = [ + f for f in os.listdir(base) if f.startswith("data_") and f.endswith(".txt") + ] return base def read_data(self, files=None): @@ -1684,12 +1935,12 @@ def read_data(self, files=None): for datatype, file in files.items(): print(f"Importing {file}...") file_path = os.path.join(self.study_id, file) - df = pd.read_csv(file_path, sep='\t', comment='#', low_memory=False) + df = pd.read_csv(file_path, sep="\t", comment="#", low_memory=False) - if 'mutations' in file: + if "mutations" in file: print(f"Binarizing and converting {file} to matrix...") df = self.binarize_mutations(df) - elif 'clinical' not in file and 'drug_treatment' not in file: + elif "clinical" not in file and "drug_treatment" not in file: print(f"Converting {file} to matrix...") df = self.process_data(df) @@ -1697,12 +1948,12 @@ def read_data(self, files=None): return data def process_data(self, 
df): - if 'Hugo_Symbol' in df.columns and 'Entrez_Gene_Id' in df.columns: - df = df.drop(columns=['Entrez_Gene_Id'], errors='ignore') + if "Hugo_Symbol" in df.columns and "Entrez_Gene_Id" in df.columns: + df = df.drop(columns=["Entrez_Gene_Id"], errors="ignore") - if 'Hugo_Symbol' in df.columns: - df = df.drop_duplicates(subset=['Hugo_Symbol']) - df.set_index('Hugo_Symbol', inplace=True) + if "Hugo_Symbol" in df.columns: + df = df.drop_duplicates(subset=["Hugo_Symbol"]) + df.set_index("Hugo_Symbol", inplace=True) return df @@ -1711,10 +1962,18 @@ def binarize_mutations(self, df): for col in required_cols: if col not in df.columns: - raise ValueError(f"Can't map mutations to sample IDs. Column {col} not found.") + raise ValueError( + f"Can't map mutations to sample IDs. Column {col} not found." + ) - mutation_counts = df.groupby(["Hugo_Symbol", "Tumor_Sample_Barcode"]).size().reset_index(name='count') - mutation_matrix = mutation_counts.pivot(index='Hugo_Symbol', columns='Tumor_Sample_Barcode', values='count').fillna(0) + mutation_counts = ( + df.groupby(["Hugo_Symbol", "Tumor_Sample_Barcode"]) + .size() + .reset_index(name="count") + ) + mutation_matrix = mutation_counts.pivot( + index="Hugo_Symbol", columns="Tumor_Sample_Barcode", values="count" + ).fillna(0) mutation_matrix[mutation_matrix > 0] = 1 return mutation_matrix @@ -1725,27 +1984,31 @@ def print_data_files(self): def get_cbioportal_data(self, study_id, files=None): archive_path = self.download_study_archive() - study_dir = self.extract_archive(archive_path) + self.extract_archive(archive_path) if files is None: self.print_data_files() - print("\n\nPlease select a list of files to import. Example:\n get_cbioportal_data('study_id', files={'mut': 'data_mutations.txt', 'clin': 'data_clinical_patient.txt'})") + print( + "\n\nPlease select a list of files to import. 
Example:\n " + "get_cbioportal_data('study_id', files={'mut': 'data_mutations.txt', " + "'clin': 'data_clinical_patient.txt'})" + ) return data = self.read_data(files) - if 'clin' in files: - clin = data['clin'] + if "clin" in files: + clin = data["clin"] clin = clin.drop_duplicates(subset=clin.columns[0]) clin.set_index(clin.columns[0], inplace=True) - data['clin'] = clin + data["clin"] = clin print({x: data[x].shape for x in data.keys()}) self.data = data def split_data(self, samples=None, ratio=0.7): if samples is None: - samples = self.data['clin'].index.tolist() + samples = self.data["clin"].index.tolist() train_samples = list(pd.Series(samples).sample(frac=ratio, random_state=42)) test_samples = list(set(samples) - set(train_samples)) @@ -1754,10 +2017,18 @@ def split_data(self, samples=None, ratio=0.7): test_data = {} for key, df in self.data.items(): - train_data[key] = df.loc[df.index.intersection(train_samples)] if key == 'clin' else df.loc[:, df.columns.intersection(train_samples)] - test_data[key] = df.loc[df.index.intersection(test_samples)] if key == 'clin' else df.loc[:, df.columns.intersection(test_samples)] + train_data[key] = ( + df.loc[df.index.intersection(train_samples)] + if key == "clin" + else df.loc[:, df.columns.intersection(train_samples)] + ) + test_data[key] = ( + df.loc[df.index.intersection(test_samples)] + if key == "clin" + else df.loc[:, df.columns.intersection(test_samples)] + ) - return {'train': train_data, 'test': test_data} + return {"train": train_data, "test": test_data} def print_dataset(self, dataset, outdir): if not os.path.exists(outdir): @@ -1769,8 +2040,7 @@ def print_dataset(self, dataset, outdir): os.makedirs(split_dir) for file, df in data.items(): - df.to_csv(os.path.join(split_dir, f"{file}.csv"), sep=',') - + df.to_csv(os.path.join(split_dir, f"{file}.csv"), sep=",") def compute_correlation_loss(embeddings, batch_labels): @@ -1778,7 +2048,9 @@ def compute_correlation_loss(embeddings, batch_labels): batch_labels = 
batch_labels.float() # Normalize embeddings - embeddings = (embeddings - embeddings.mean(dim=0, keepdim=True)) / (embeddings.std(dim=0, keepdim=True) + 1e-8) + embeddings = (embeddings - embeddings.mean(dim=0, keepdim=True)) / ( + embeddings.std(dim=0, keepdim=True) + 1e-8 + ) # Normalize batch labels batch_labels = (batch_labels - batch_labels.mean()) / (batch_labels.std() + 1e-8) @@ -1793,7 +2065,7 @@ def compute_correlation_loss(embeddings, batch_labels): loss = torch.sum(torch.abs(covariance)) return loss -# from geomloss import SamplesLoss + def compute_transport_cost(embeddings, batch_labels, blur=0.5): """ Compute a transport cost using Sinkhorn loss to align embeddings between batches. @@ -1814,7 +2086,14 @@ def compute_transport_cost(embeddings, batch_labels, blur=0.5): batch2_embeddings = embeddings[batch_labels == 1] if batch1_embeddings.size(0) == 0 or batch2_embeddings.size(0) == 0: - raise ValueError("Both batches must have at least one sample for transport cost computation.") + raise ValueError( + "Both batches must have at least one sample for transport cost computation." + ) + + if SamplesLoss is None: + raise ImportError( + "geomloss is required for compute_transport_cost. Install it with: pip install geomloss" + ) # Initialize the Sinkhorn loss function loss_fn = SamplesLoss("sinkhorn", blur=blur) @@ -1855,29 +2134,33 @@ def get_optimal_device(device_preference=None): - device_type: String indicating the device type for compatibility """ if device_preference is None: - device_preference = 'auto' + device_preference = "auto" # If specific device is requested, validate and return it - if device_preference == 'cuda': + if device_preference == "cuda": if torch.cuda.is_available(): - return 'cuda', 'gpu' + return "cuda", "gpu" else: - warnings.warn("CUDA requested but not available. Falling back to auto-detection.") - elif device_preference == 'mps': + warnings.warn( + "CUDA requested but not available. Falling back to auto-detection." 
+ ) + elif device_preference == "mps": if torch.backends.mps.is_available(): - return 'mps', 'mps' + return "mps", "mps" else: - warnings.warn("MPS requested but not available. Falling back to auto-detection.") - elif device_preference == 'cpu': - return 'cpu', 'cpu' + warnings.warn( + "MPS requested but not available. Falling back to auto-detection." + ) + elif device_preference == "cpu": + return "cpu", "cpu" # Auto-detection logic (priority: CUDA > MPS > CPU) if torch.cuda.is_available(): - return 'cuda', 'gpu' + return "cuda", "gpu" elif torch.backends.mps.is_available(): - return 'mps', 'mps' + return "mps", "mps" else: - return 'cpu', 'cpu' + return "cpu", "cpu" def get_device_memory_info(device_str): @@ -1890,30 +2173,30 @@ def get_device_memory_info(device_str): Returns: dict: Memory information dictionary """ - if device_str == 'cuda' and torch.cuda.is_available(): + if device_str == "cuda" and torch.cuda.is_available(): return { - 'allocated': torch.cuda.memory_allocated() / (1024**2), # MB - 'reserved': torch.cuda.memory_reserved() / (1024**2), # MB - 'max_allocated': torch.cuda.max_memory_allocated() / (1024**2), # MB - 'device_name': torch.cuda.get_device_name(0), - 'device_count': torch.cuda.device_count() + "allocated": torch.cuda.memory_allocated() / (1024**2), # MB + "reserved": torch.cuda.memory_reserved() / (1024**2), # MB + "max_allocated": torch.cuda.max_memory_allocated() / (1024**2), # MB + "device_name": torch.cuda.get_device_name(0), + "device_count": torch.cuda.device_count(), } - elif device_str == 'mps' and torch.backends.mps.is_available(): + elif device_str == "mps" and torch.backends.mps.is_available(): # MPS doesn't have the same detailed memory tracking as CUDA return { - 'allocated': 'N/A (MPS)', - 'reserved': 'N/A (MPS)', - 'max_allocated': 'N/A (MPS)', - 'device_name': 'Apple Metal Performance Shaders', - 'device_count': 1 + "allocated": "N/A (MPS)", + "reserved": "N/A (MPS)", + "max_allocated": "N/A (MPS)", + "device_name": 
"Apple Metal Performance Shaders", + "device_count": 1, } else: return { - 'allocated': 'N/A (CPU)', - 'reserved': 'N/A (CPU)', - 'max_allocated': 'N/A (CPU)', - 'device_name': 'CPU', - 'device_count': 1 + "allocated": "N/A (CPU)", + "reserved": "N/A (CPU)", + "max_allocated": "N/A (CPU)", + "device_name": "CPU", + "device_count": 1, } @@ -1927,20 +2210,20 @@ def create_device_from_string(device_str): Returns: torch.device: PyTorch device object """ - if device_str in ['gpu', 'auto']: + if device_str in ["gpu", "auto"]: optimal_device, _ = get_optimal_device() return torch.device(optimal_device) - elif device_str == 'mps': + elif device_str == "mps": if torch.backends.mps.is_available(): - return torch.device('mps') + return torch.device("mps") else: warnings.warn("MPS not available, falling back to CPU") - return torch.device('cpu') - elif device_str == 'cuda': + return torch.device("cpu") + elif device_str == "cuda": if torch.cuda.is_available(): - return torch.device('cuda') + return torch.device("cuda") else: warnings.warn("CUDA not available, falling back to CPU") - return torch.device('cpu') + return torch.device("cpu") else: - return torch.device('cpu') + return torch.device("cpu") diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index 231822dd..ca431cf4 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -1,26 +1,27 @@ -import torch import pytest +import torch + from flexynesis.utils import get_optimal_device, to_device_safe def test_mps_device_detection(): """Test MPS device detection via Flexynesis.""" - device_str, device_type = get_optimal_device('mps') - + device_str, device_type = get_optimal_device("mps") + if not torch.backends.mps.is_available(): pytest.skip("MPS device not available. 
Skipping test.") - - assert device_type == 'mps', f"Expected device type 'mps', got '{device_type}'" - assert device_str == 'mps', f"Expected device string 'mps', got '{device_str}'" + + assert device_type == "mps", f"Expected device type 'mps', got '{device_type}'" + assert device_str == "mps", f"Expected device string 'mps', got '{device_str}'" def test_mps_tensor_operations(): """Test MPS tensor operations via Flexynesis safe transfer functions.""" - device_str, device_type = get_optimal_device('mps') - - if device_str != 'mps' or not torch.backends.mps.is_available(): + device_str, device_type = get_optimal_device("mps") + + if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - + device = torch.device(device_str) # Test basic tensor operations using safe transfer @@ -28,22 +29,26 @@ def test_mps_tensor_operations(): y = to_device_safe(torch.randn(50, 25), device) result = torch.mm(x, y) - assert result.shape == (100, 25), f"Unexpected result shape: {result.shape}" - assert result.device.type == "mps", f"Result not on MPS device: {result.device.type}" + assert result.shape == ( + 100, + 25, + ), f"Unexpected result shape: {result.shape}" + assert ( + result.device.type == "mps" + ), f"Result not on MPS device: {result.device.type}" def test_mps_memory_allocation(): """Test MPS memory allocation tracking.""" - device_str, device_type = get_optimal_device('mps') - - if device_str != 'mps' or not torch.backends.mps.is_available(): + device_str, device_type = get_optimal_device("mps") + + if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. 
Skipping test.") - - device = torch.device(device_str) + device = torch.device(device_str) # Test memory tracking memory_before = torch.mps.current_allocated_memory() - large_tensor = to_device_safe(torch.randn(1000, 1000), device) + large_tensor = to_device_safe(torch.randn(1000, 1000), device) # noqa: F841 memory_after = torch.mps.current_allocated_memory() assert memory_after > memory_before, "Memory did not increase after allocation." @@ -51,11 +56,11 @@ def test_mps_memory_allocation(): def test_float64_to_float32_conversion(): """Test automatic float64 to float32 conversion for MPS compatibility.""" - device_str, device_type = get_optimal_device('mps') - - if device_str != 'mps' or not torch.backends.mps.is_available(): + device_str, device_type = get_optimal_device("mps") + + if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - + device = torch.device(device_str) # Test float64 tensor conversion diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py index 701b7349..0bcee2ab 100644 --- a/tests/unit/test_smoke.py +++ b/tests/unit/test_smoke.py @@ -1,2 +1,2 @@ def test_smoke(): - import flexynesis + import flexynesis # noqa: F401