diff --git a/.gitignore b/.gitignore index b7434a582..002ef920f 100755 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,11 @@ notebooks/tmp /tutorials/lightning_logs/ /tutorials/datasets/ +# scripts +scripts/hopse_plotting/csvs +scripts/hopse_plotting/plots +scripts/hopse_plotting/tables + # wandb wandb/ result_BREC/ diff --git a/configs/dataset/graph/BBB_Martins.yaml b/configs/dataset/graph/BBB_Martins.yaml new file mode 100644 index 000000000..645beb9b7 --- /dev/null +++ b/configs/dataset/graph/BBB_Martins.yaml @@ -0,0 +1,36 @@ +# Dataset loader config +loader: + _target_: topobench.data.loaders.ADMEDatasetLoader + parameters: + data_domain: graph + data_type: ADME + data_name: BBB_Martins + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +# Dataset parameters +parameters: + num_features: + - 9 # OGB atom features + - 3 # OGB edge features + num_classes: 2 + task: classification + loss_type: cross_entropy + monitor_metric: accuracy + task_level: graph + max_dim_if_lifted: 3 + preserve_edge_attr_if_lifted: ${set_preserve_edge_attr:${model.model_name},True} + +# Splits - using fixed scaffold split from TDC +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: fixed # TDC provides predefined scaffold splits + k: 10 + train_prop: 0.5 + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 0 + pin_memory: False diff --git a/configs/dataset/graph/CYP3A4_Veith.yaml b/configs/dataset/graph/CYP3A4_Veith.yaml new file mode 100644 index 000000000..24fb40826 --- /dev/null +++ b/configs/dataset/graph/CYP3A4_Veith.yaml @@ -0,0 +1,36 @@ +# Dataset loader config +loader: + _target_: topobench.data.loaders.ADMEDatasetLoader + parameters: + data_domain: graph + data_type: ADME + data_name: CYP3A4_Veith + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +# Dataset parameters +parameters: + num_features: + - 9 # OGB atom features + - 3 # OGB edge features + num_classes: 2 + task: classification + loss_type: cross_entropy + monitor_metric: accuracy + task_level: graph + max_dim_if_lifted: 3 + preserve_edge_attr_if_lifted: ${set_preserve_edge_attr:${model.model_name},True} + +# Splits - using fixed scaffold split from TDC +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: fixed # TDC provides predefined scaffold splits + k: 10 + train_prop: 0.5 + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 0 + pin_memory: False diff --git a/configs/dataset/graph/Caco2_Wang.yaml b/configs/dataset/graph/Caco2_Wang.yaml new file mode 100644 index 000000000..b5c5b5feb --- /dev/null +++ b/configs/dataset/graph/Caco2_Wang.yaml @@ -0,0 +1,36 @@ +# Dataset loader config +loader: + _target_: topobench.data.loaders.ADMEDatasetLoader + parameters: + data_domain: graph + data_type: ADME + data_name: Caco2_Wang + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +# Dataset parameters +parameters: + num_features: + - 9 # OGB atom features + - 3 # OGB edge features + num_classes: 1 # Regression task + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + max_dim_if_lifted: 3 + preserve_edge_attr_if_lifted: ${set_preserve_edge_attr:${model.model_name},True} + +# Splits - 
using fixed scaffold split from TDC +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: fixed # TDC provides predefined scaffold splits + k: 10 + train_prop: 0.5 + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 0 + pin_memory: False \ No newline at end of file diff --git a/configs/dataset/graph/Clearance_Hepatocyte_AZ.yaml b/configs/dataset/graph/Clearance_Hepatocyte_AZ.yaml new file mode 100644 index 000000000..7081a326c --- /dev/null +++ b/configs/dataset/graph/Clearance_Hepatocyte_AZ.yaml @@ -0,0 +1,36 @@ +# Dataset loader config +loader: + _target_: topobench.data.loaders.ADMEDatasetLoader + parameters: + data_domain: graph + data_type: ADME + data_name: Clearance_Hepatocyte_AZ + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +# Dataset parameters +parameters: + num_features: + - 9 # OGB atom features + - 3 # OGB edge features + num_classes: 1 # Regression task + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + max_dim_if_lifted: 3 + preserve_edge_attr_if_lifted: ${set_preserve_edge_attr:${model.model_name},True} + +# Splits - using fixed scaffold split from TDC +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: fixed # TDC provides predefined scaffold splits + k: 10 + train_prop: 0.5 + +# Dataloader parameters +dataloader_params: + batch_size: 64 + num_workers: 0 + pin_memory: False \ No newline at end of file diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml index d09ee508d..69c692c78 100644 --- a/configs/hydra/default.yaml +++ b/configs/hydra/default.yaml @@ -7,9 +7,9 @@ defaults: # output directory, generated dynamically on each run run: - dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} + dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}_${pid:} sweep: - dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}_${pid:} subdir: ${hydra.job.num} job_logging: diff --git a/configs/run.yaml b/configs/run.yaml index 38f3eab97..eaf1d4b92 100755 --- a/configs/run.yaml +++ b/configs/run.yaml @@ -41,6 +41,7 @@ train: True # evaluate on test set, using best model weights achieved during training # lightning chooses best weights based on the metric specified in checkpoint callback test: True +delete_checkpoint_after_test: False # simply provide checkpoint path to resume training ckpt_path: null diff --git a/configs/transforms/data_manipulations/hopse_ps_information.yaml b/configs/transforms/data_manipulations/hopse_ps_information.yaml index 0a68da432..233ae02d5 100644 --- a/configs/transforms/data_manipulations/hopse_ps_information.yaml +++ b/configs/transforms/data_manipulations/hopse_ps_information.yaml @@ -42,6 +42,7 @@ parameters: include_eigenvalues: false include_first: false concat_to_x: false + device: 'cpu' # Force CPU for eigen computations RWSE: max_pe_dim: 10 concat_to_x: false diff --git a/configs/transforms/data_manipulations/precompute_khop_features.yaml b/configs/transforms/data_manipulations/precompute_khop_features.yaml index e249ed1ed..432cc65db 100644 --- a/configs/transforms/data_manipulations/precompute_khop_features.yaml +++ 
b/configs/transforms/data_manipulations/precompute_khop_features.yaml @@ -3,7 +3,7 @@ transform_type: "data manipulation" max_hop: 1 use_initial_features: True complex_dim: ${oc.select:dataset.parameters.max_dim_if_lifted,3} -in_channels: ${infer_in_khop_feature_dim:${model.feature_encoder.dataset_in_channels},${.max_hop}} +in_channels: ${infer_in_khop_feature_dim:${model.feature_encoder.dataset_in_channels},${.max_hop},${.complex_dim}} max_rank: 2 # in_features: ${infer_in_sann_khop_feature_dim:${model},${3}} diff --git a/configs/transforms/model_dataset_defaults/gps_PROTEINS.yaml b/configs/transforms/model_dataset_defaults/gps_PROTEINS.yaml new file mode 100644 index 000000000..d6d883370 --- /dev/null +++ b/configs/transforms/model_dataset_defaults/gps_PROTEINS.yaml @@ -0,0 +1,4 @@ +defaults: + - data_manipulations: identity # PROTEINS dataset needs identity transform to avoid adding random float feature to feature matrix + - data_manipulations@CombinedPSEs: combined_positional_and_structural_encodings + - liftings@_here_: ${get_required_lifting:${dataset},${model}} diff --git a/configs/transforms/model_dataset_defaults/hopse_g_PROTEINS.yaml b/configs/transforms/model_dataset_defaults/hopse_g_PROTEINS.yaml new file mode 100644 index 000000000..d24682390 --- /dev/null +++ b/configs/transforms/model_dataset_defaults/hopse_g_PROTEINS.yaml @@ -0,0 +1,4 @@ +defaults: + - data_manipulations: identity # PROTEINS dataset needs identity transform to avoid adding random float feature to feature matrix + - liftings@_here_: ${get_required_lifting:${dataset},${model}} + - data_manipulations@hopse_encoding: add_gpse_information diff --git a/configs/transforms/model_dataset_defaults/hopse_m_PROTEINS.yaml b/configs/transforms/model_dataset_defaults/hopse_m_PROTEINS.yaml new file mode 100644 index 000000000..dca7dab5e --- /dev/null +++ b/configs/transforms/model_dataset_defaults/hopse_m_PROTEINS.yaml @@ -0,0 +1,4 @@ +defaults: + - data_manipulations: identity # PROTEINS dataset needs identity transform to avoid adding random float feature to feature matrix + - liftings@_here_: ${get_required_lifting:${dataset},${model}} + - data_manipulations@hopse_encoding: hopse_ps_information diff --git a/configs/transforms/model_dataset_defaults/hopse_m_ZINC.yaml b/configs/transforms/model_dataset_defaults/hopse_m_ZINC.yaml index 4b7a42c05..1117fa0cd 100644 --- a/configs/transforms/model_dataset_defaults/hopse_m_ZINC.yaml +++ b/configs/transforms/model_dataset_defaults/hopse_m_ZINC.yaml @@ -3,34 +3,6 @@ defaults: - liftings@_here_: ${get_required_lifting:${dataset},${model}} - data_manipulations@hopse_encoding: hopse_ps_information -hopse_encoding: - pe_types: - - 'RWSE' - - 'ElstaticPE' - - 'HKdiagSE' - - 'LapPE' - - - # Different PS have different sizes, need to unify them. 
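
Both config changes above lean on custom OmegaConf resolvers: `${pid:}` in the Hydra run/sweep dirs and the extended `${infer_in_khop_feature_dim:...}` that now also receives `complex_dim`. A minimal sketch of how such resolvers are registered; the sizing rule for the k-hop resolver below is a made-up placeholder, not the actual TopoBench implementation.

```python
# Illustrative only: the resolver names come from the configs above, but the
# registration point and the k-hop sizing rule are assumptions.
import os
from omegaconf import OmegaConf

# "${pid:}" -> per-process suffix so concurrent runs launched in the same
# second get distinct run/sweep directories.
OmegaConf.register_new_resolver("pid", lambda: str(os.getpid()))

# "${infer_in_khop_feature_dim:<dataset_in_channels>,<max_hop>,<complex_dim>}"
# The new third argument lets the resolver emit one input dimension per rank
# of the lifted complex (hypothetical sizing rule, for illustration).
def infer_in_khop_feature_dim(dataset_in_channels, max_hop, complex_dim):
    per_rank = [int(c) * (int(max_hop) + 1) for c in dataset_in_channels]
    return (per_rank + per_rank[-1:] * int(complex_dim))[: int(complex_dim)]

OmegaConf.register_new_resolver(
    "infer_in_khop_feature_dim", infer_in_khop_feature_dim
)

print(OmegaConf.create({"run_suffix": "run_${pid:}"}).run_suffix)  # e.g. run_12345
print(infer_in_khop_feature_dim([9, 3], 1, 3))                     # [18, 6, 6]
```
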
-  target_pe_dim: 20
-
-  # LapPE config
-  laplacian_norm_type: 'sym'
-  posenc_LapPE_eigen_max_freqs: 18
-  posenc_LapPE_eigen_eigvec_norm: 'L2'
-  posenc_LapPE_eigen_skip_zero_freq: True
-  posenc_LapPE_eigen_eigvec_abs: True
-
-  # RWSE config
-  kernel_param_RWSE:
-    - 2
-    - 20
-
-  # HKdiagSE config
-  kernel_param_HKdiagSE:
-    - 1
-    - 22
-
 one_hot_node_degree_features:
   degrees_field: x
   features_field: x
diff --git a/configs/transforms/model_dataset_defaults/sann_PROTEINS.yaml b/configs/transforms/model_dataset_defaults/sann_PROTEINS.yaml
new file mode 100644
index 000000000..37e3e7737
--- /dev/null
+++ b/configs/transforms/model_dataset_defaults/sann_PROTEINS.yaml
@@ -0,0 +1,4 @@
+defaults:
+  - data_manipulations: identity # PROTEINS dataset needs identity transform to avoid adding random float feature to feature matrix
+  - liftings@_here_: ${get_required_lifting:${dataset},${model}}
+  - data_manipulations@sann_encoding: precompute_khop_features
diff --git a/pyproject.toml b/pyproject.toml
index 8622d5125..fce9fc188 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,6 +60,10 @@ dependencies=[
   "torch-scatter",
   "torch-sparse",
   "torch-cluster",
+  "rdkit-pypi",
+  "PyTDC==1.1.15",
+  # PyTDC imports pkg_resources; setuptools>=82 dropped it.
+  "setuptools>=69,<82",
 ]

 [project.optional-dependencies]
diff --git a/scripts/cwn.sh b/scripts/cwn.sh
new file mode 100644
index 000000000..8cf355f0d
--- /dev/null
+++ b/scripts/cwn.sh
@@ -0,0 +1,425 @@
+#!/bin/bash
+# ==============================================================================
+# SCRIPT: cwn.sh
+# DESCRIPTION:
+#   Runs a scalable hyperparameter sweep for CWN (Cell Weisfeiler-Lehman
+#   Network) across graph datasets lifted to cell complexes.
+#   - ARCHITECTURE: Uses a "Cartesian Product" generation strategy.
+#   - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU.
+#   - ORDERING: Prioritizes running all seeds for a config before moving on.
+#   - FILTERING: Transductive datasets forced to batch_size=1.
+# ==============================================================================
+
+export SELECTED_GPUS="0,1,2,3,4,5,6,7"
+wandb_entity="gbg141-hopse"
+RESUME=true # Set to true to skip already-completed runs (reads SUCCESSFUL_RUNS.log)
+
+# ==============================================================================
+# SECTION 1: LOGGING & ENVIRONMENT SETUP
+# ==============================================================================
+
+# 1.1 Define Project Identifiers
+script_name="$(basename "${BASH_SOURCE[0]}" .sh)"
+project_name="${script_name}"
+log_group="cwn_sweep"
+LOG_DIR="./logs/${log_group}"
+
+echo "=========================================================="
+echo " Preparing log directory: $LOG_DIR"
+echo "=========================================================="
+
+# 1.2 Log directory management
+if [[ "$RESUME" == "true" ]]; then
+  echo "⏩ RESUME MODE: Keeping existing logs."
+ mkdir -p "$LOG_DIR" +else + if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi + mkdir -p "$LOG_DIR" +fi + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 + +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Thresholds: >= 80 GB -> 5 jobs, <= 10 GB -> 1 job, <= 30 GB -> 2 jobs, else 3. +_gpu_info=$(python3 -c " +import subprocess +import os + +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + idx = idx.strip() + if allowed_gpus and idx not in allowed_gpus: + continue + indices.append(idx) + mem_mb.append(int(mem.strip())) + if not indices: + print('0') + exit(0) + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 5 + elif min_mem_gb <= 10: + jobs = 1 + elif min_mem_gb <= 30: + jobs = 2 + else: + jobs = 3 + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Model --- +# CWN is a cell-domain model; it is applied to graph datasets via graph2cell lifting. +models=( + "cell/cwn" +) + +# --- Datasets --- +# CWN requires cell-complex data. Only graph datasets are included here because +# they can be lifted to cell complexes (via cycle lifting). Simplicial datasets +# (mantra) cannot be lifted to cell complexes in this framework. 
+datasets=( + "graph/MUTAG" + "graph/PROTEINS" + "graph/NCI1" + "graph/NCI109" + "graph/BBB_Martins" + "graph/Caco2_Wang" + "graph/Clearance_Hepatocyte_AZ" + "graph/CYP3A4_Veith" + "graph/cocitation_cora" + "graph/cocitation_citeseer" + "graph/cocitation_pubmed" + "graph/ZINC" +) + +# --- Hyperparameters --- +num_layers=(1 2 4) +hidden_channels=(64 128 256) +proj_dropouts=(0.0 0.25) +lrs=(0.01 0.001) +weight_decays=(0.0001) +batch_sizes=(128 256) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=10" + "delete_checkpoint_after_test=True" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING +# Format: "ShortTag | HydraKey | ${Array[*]}" +# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.n_layers|${num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Transductive Filtering) +# ============================================================================== + +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size to avoid duplicating transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. 
Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Transductive Batch Size Handler --- + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + is_transductive = False + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'WARNING: Could not find config at {yaml_path}', file=sys.stderr) + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + if current_bs != first_bs: + skipped += 1 + continue + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + for (tag, key, val) in combo: + if '::' in val: + alias, actual_val = val.split('::', 1) + clean_val = alias + else: + clean_val = os.path.basename(val) + actual_val = val + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + cmd_args.append(f'{key}={actual_val}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + + +# ============================================================================== +# SECTION 5.5: RESUME — LOAD COMPLETED RUNS +# ============================================================================== + +declare -A _completed_runs +if [[ "$RESUME" == "true" ]]; then + _success_log="$LOG_DIR/$log_group/SUCCESSFUL_RUNS.log" + if [[ -f "$_success_log" ]]; then + while IFS= read -r _line; do + _rname="${_line##*\[SUCCESS\] }" + _completed_runs["$_rname"]=1 + done < "$_success_log" + echo "✔ Loaded ${#_completed_runs[@]} completed runs to skip." + else + echo "⚠️ No SUCCESSFUL_RUNS.log found at $_success_log — nothing to skip." + fi +fi + + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." 
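
For reference, a small Python mirror of the resume logic in Section 5.5: the set of already-finished runs is recovered from SUCCESSFUL_RUNS.log, assuming run_and_log (defined in scripts/base/logging.sh, not part of this diff) appends lines that end with "[SUCCESS] <run_name>".

```python
# Mirrors Section 5.5 above (illustrative; the exact log format is an assumption
# about what run_and_log writes).
from pathlib import Path

def load_completed_runs(log_dir: str, log_group: str) -> set:
    success_log = Path(log_dir) / log_group / "SUCCESSFUL_RUNS.log"
    if not success_log.exists():
        return set()
    completed = set()
    for line in success_log.read_text().splitlines():
        # Same idea as the bash expansion "${_line##*\[SUCCESS\] }":
        # keep everything after the last "[SUCCESS] " marker.
        if "[SUCCESS] " in line:
            completed.add(line.rsplit("[SUCCESS] ", 1)[1].strip())
    return completed

# Runs whose generated name is in this set are skipped in step 6.2.1 below.
completed = load_completed_runs("./logs/cwn_sweep", "cwn_sweep")
```
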
+echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +skipped_completed=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.2.1 Skip if already completed (RESUME mode) + if [[ "$RESUME" == "true" && -n "${_completed_runs[$run_name]+x}" ]]; then + ((skipped_completed++)) + continue + fi + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + + # Extract dataset name for dynamic W&B project + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + "+logger.wandb.name=${run_name}" + ) + + # 6.6 Execute + run_and_log "${cmd[*]}" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total, $skipped_completed skipped as already completed)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait diff --git a/scripts/gat.sh b/scripts/gat.sh new file mode 100644 index 000000000..005deb814 --- /dev/null +++ b/scripts/gat.sh @@ -0,0 +1,451 @@ +#!/bin/bash +# ============================================================================== +# SCRIPT: gat.sh +# DESCRIPTION: +# Runs a scalable hyperparameter sweep for HOPSE_M models across both +# simplicial and cellular domains. +# - ARCHITECTURE: Uses a "Cartesian Product" generation strategy. +# - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU. +# - ORDERING: Prioritizes running all seeds for a config before moving on. +# - FILTERING: Skips invalid model+dataset combos (cell + simplicial data). 
+# ============================================================================== + +# ============================================================================== +# SECTION 1: LOGGING & ENVIRONMENT SETUP +# ============================================================================== + +# 1.1 Define Project Identifiers +script_name="$(basename "${BASH_SOURCE[0]}" .sh)" +project_name="${script_name}" +log_group="gat_sweep" +LOG_DIR="./logs/${log_group}" +wandb_entity="gbg141-hopse" + +echo "==========================================================" +echo " Preparing log directory: $LOG_DIR" +echo "==========================================================" + +# 1.2 Clean up old logs to ensure a fresh run +if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi +mkdir -p "$LOG_DIR" + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + + +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Output format: "JOBS_PER_GPU gpu_id_0 gpu_id_1 ..." +# Thresholds: >= 80 GB -> 4 jobs, <= 30 GB -> 2 jobs, between -> 3 jobs. +export SELECTED_GPUS="2,3,4,5,6,7" + +_gpu_info=$(python3 -c " +import subprocess +import os + +# 1. Read the allowed GPUs from the environment variable +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + idx = idx.strip() + + # 2. 
Skip this GPU if it's not in our selected list + if allowed_gpus and idx not in allowed_gpus: + continue + + indices.append(idx) + mem_mb.append(int(mem.strip())) + + # Safety check in case the selected GPUs don't exist + if not indices: + print('0') + exit(0) + + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 4 + elif min_mem_gb <= 30: + jobs = 2 + else: + jobs = 3 + + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Models --- +models=( + "gat::graph/gat" +) + +# --- Datasets --- +datasets=( + "graph/MUTAG" + "graph/cocitation_cora" + "graph/PROTEINS" + "graph/NCI1" + "graph/NCI109" + "graph/ZINC" + "graph/cocitation_citeseer" + "graph/cocitation_pubmed" + "simplicial/mantra_name" + "simplicial/mantra_orientation" + "simplicial/mantra_betti_numbers" +) + +# --- Transforms (Hydra: configs/transforms/.yaml) --- +# combined_pe / combined_fe nest data_manipulations under CombinedPSEs / CombinedFEs (encoding lists in those YAMLs). +# Extra Hydra flags: @@@ between full key=value pieces, e.g. +# "pse::combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE]" +transform_presets=( + "notf::no_transform" + "pse::combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE]" + "fe::combined_fe@@@transforms.CombinedFEs.encodings=[HKFE,KHopFE,PPRFE]" +) + +# --- Hyperparameters (superset across all dataset groups) --- +num_layers=(1 2 4) +hidden_channels=(128 256) +proj_dropouts=(0.25 0.5) +lrs=(0.01 0.001) +weight_decays=(0 0.0001) +batch_sizes=(128 256) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=5" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING (CRITICAL ORDERING) +# Format: "ShortTag | HydraKey | ${Array[*]}" +# +# Values support an optional "alias::hydra_value" syntax for readable names. +# Use @@@ in hydra_value to emit several space-separated CLI overrides (see transform_presets). +# The generator also filters out invalid model+dataset combos. 
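+#   Worked example: the preset "pse::combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE]"
+#   contributes "tfpse" to the run name, while its hydra_value is split on "@@@" into the raw pieces
+#   "combined_pe" and "transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE]";
+#   repair_hydra_transforms_arg in Section 6 then rewrites the bare "combined_pe" into "transforms=combined_pe".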
+# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + "tf|transforms|${transform_presets[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.num_layers|${num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Smart Transductive Filtering) +# ============================================================================== + +# Define where your dataset YAMLs live so the generator can inspect them. +# UPDATE THIS PATH IF YOUR CONFIGS ARE STORED ELSEWHERE. +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size in the sweep so we don't duplicate transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + model_val = vals_dict.get('model', '') + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Rule A: Skip cell model + simplicial dataset --- + if model_val.startswith('cell/') and dataset_val.startswith('simplicial/'): + skipped += 1 + continue + + # --- Rule B: Transductive Batch Size Handler --- + is_transductive = False + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + # Construct path to yaml (e.g., ./configs/dataset/graph/cocitation_cora.yaml) + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + # Fast text check avoids needing pip install pyyaml + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'⚠️ WARNING: Could not find config at {yaml_path}', file=sys.stderr) + + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + # If this isn't the first batch size in the sweep list, skip it + # to avoid running the exact same bs=1 experiment multiple times. + if current_bs != first_bs: + skipped += 1 + continue + + # Mutate the current combination to force batch_size to 1 + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + # Force the value to 1. If an alias was used, keep it clean. 
+ new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + for (tag, key, val) in combo: + if '::' in val: + alias, hydra_val_str = val.split('::', 1) + clean_val = alias + actual_val = hydra_val_str + else: + clean_val = os.path.basename(val) + actual_val = val + + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + if '@@@' in actual_val: + for part in actual_val.split('@@@'): + part = part.strip() + if part: + cmd_args.append(part) + else: + cmd_args.append(f'{key}={actual_val}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +# If IFS was polluted, read can split transforms=combined_pe; Hydra then errors on bare "combined_pe". +repair_hydra_transforms_arg() { + local -n _r=$1 + local out=() i + for ((i = 0; i < ${#_r[@]}; i++)); do + local t="${_r[i]}" + if [[ "$t" == transforms=* ]]; then + out+=("$t") + elif [[ "$t" == "transforms" && $((i + 1)) -lt ${#_r[@]} ]]; then + local nxt="${_r[$((i + 1))]}" + [[ "$nxt" == *"="* ]] && { out+=("$t"); continue; } + out+=("transforms=$nxt") + ((i++)) + elif [[ "$t" =~ ^(combined_pe|combined_fe|no_transform)$ ]]; then + out+=("transforms=$t") + else + out+=("$t") + fi + done + _r=("${out[@]}") +} + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." +echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + # Must not use inherited IFS (e.g. 
IFS== splits transforms=combined_pe → bare "combined_pe" for Hydra) + IFS=$' \t\n' read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + repair_hydra_transforms_arg DYNAMIC_ARGS_ARRAY + + # --- Extract dataset name for dynamic W&B project --- + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + ) + + # 6.6 Execute — printf %q so run_and_log's eval keeps key=value overrides as single words + # (broken IFS or nullglob in the parent shell can otherwise split transforms=combined_pe, etc.) + cmd_eval=$(printf '%q ' "${cmd[@]}") + run_and_log "${cmd_eval% }" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait +echo "✔ All runs complete." diff --git a/scripts/gin.sh b/scripts/gin.sh new file mode 100755 index 000000000..3aa9072bb --- /dev/null +++ b/scripts/gin.sh @@ -0,0 +1,439 @@ +#!/bin/bash +# ============================================================================== +# SCRIPT: gin.sh +# DESCRIPTION: +# Runs a scalable hyperparameter sweep for HOPSE_M models across both +# simplicial and cellular domains. +# - ARCHITECTURE: Uses a "Cartesian Product" generation strategy. +# - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU. +# - ORDERING: Prioritizes running all seeds for a config before moving on. +# - FILTERING: Skips invalid model+dataset combos (cell + simplicial data). 
+# ============================================================================== + +# ============================================================================== +# SECTION 1: LOGGING & ENVIRONMENT SETUP +# ============================================================================== + +# 1.1 Define Project Identifiers +script_name="$(basename "${BASH_SOURCE[0]}" .sh)" +project_name="${script_name}" +log_group="gin_sweep" +LOG_DIR="./logs/${log_group}" +wandb_entity="gbg141-hopse" + +echo "==========================================================" +echo " Preparing log directory: $LOG_DIR" +echo "==========================================================" + +# 1.2 Clean up old logs to ensure a fresh run +if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi +mkdir -p "$LOG_DIR" + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 +# ============================================================================== + +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Output format: "JOBS_PER_GPU gpu_id_0 gpu_id_1 ..." +# Thresholds: >= 80 GB -> 4 jobs, <= 30 GB -> 2 jobs, between -> 3 jobs. 
+_gpu_info=$(python3 -c " +import subprocess +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + indices.append(idx.strip()) + mem_mb.append(int(mem.strip())) + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 4 + elif min_mem_gb <= 30: + jobs = 4 + else: + jobs = 3 + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Models --- +models=( + "gin::graph/gin" +) + +# --- Datasets --- +datasets=( + "graph/MUTAG" + "graph/PROTEINS" + "graph/NCI1" + "graph/NCI109" + "simplicial/mantra_name" + "simplicial/mantra_orientation" + "simplicial/mantra_betti_numbers" + # "graph/ZINC" + "graph/cocitation_cora" + "graph/cocitation_citeseer" + "graph/cocitation_pubmed" +) + +# --- Transforms (Hydra: configs/transforms/.yaml) --- +# combined_pe / combined_fe nest data_manipulations under CombinedPSEs / CombinedFEs (encoding lists in those YAMLs). +# Extra Hydra flags: @@@ between full key=value pieces, e.g. +# "pse::combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE]" +transform_presets=( + "notf::no_transform" + "pse::combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE]" + "fe::combined_fe@@@transforms.CombinedFEs.encodings=[HKFE,KHopFE,PPRFE]" +) + +# --- Hyperparameters (superset across all dataset groups) --- +num_layers=(1 2 4) +hidden_channels=(128 256) +proj_dropouts=(0.25 0.5) +lrs=(0.01 0.001) +weight_decays=(0.0001) +batch_sizes=(128 256) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=5" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING (CRITICAL ORDERING) +# Format: "ShortTag | HydraKey | ${Array[*]}" +# +# Values support an optional "alias::hydra_value" syntax for readable names. +# Use @@@ in hydra_value to emit several space-separated CLI overrides (see transform_presets). +# The generator also filters out invalid model+dataset combos. 
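
As an aside, a minimal sketch of the RWSE encoding named in the transform presets above: the k-step return probabilities diag((D^-1 A)^k) per node. This is the textbook definition, not the actual TopoBench transform, whose kernel_param handling may differ.

```python
# Minimal RWSE sketch (diagonals of random-walk matrix powers). Illustrative only.
import torch
from torch_geometric.utils import to_dense_adj

def rwse(edge_index: torch.Tensor, num_nodes: int, max_k: int = 10) -> torch.Tensor:
    """Return a [num_nodes, max_k] matrix of k-step return probabilities."""
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]   # [N, N]
    deg = adj.sum(dim=1).clamp(min=1.0)
    rw = adj / deg.unsqueeze(1)                                  # D^-1 A
    pe, walk = [], torch.eye(num_nodes)
    for _ in range(max_k):
        walk = walk @ rw
        pe.append(torch.diagonal(walk))                          # P[k-step walk returns]
    return torch.stack(pe, dim=1)

# Example: on a 4-cycle the return probability is 0 after odd steps, >0 after even ones.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]])
print(rwse(edge_index, num_nodes=4, max_k=4))
```
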
+# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + "tf|transforms|${transform_presets[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.num_layers|${num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Smart Transductive Filtering) +# ============================================================================== + +# Define where your dataset YAMLs live so the generator can inspect them. +# UPDATE THIS PATH IF YOUR CONFIGS ARE STORED ELSEWHERE. +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size in the sweep so we don't duplicate transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + model_val = vals_dict.get('model', '') + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Rule A: Skip cell model + simplicial dataset --- + if model_val.startswith('cell/') and dataset_val.startswith('simplicial/'): + skipped += 1 + continue + + # --- Rule B: Transductive Batch Size Handler --- + is_transductive = False + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + # Construct path to yaml (e.g., ./configs/dataset/graph/cocitation_cora.yaml) + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + # Fast text check avoids needing pip install pyyaml + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'⚠️ WARNING: Could not find config at {yaml_path}', file=sys.stderr) + + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + # If this isn't the first batch size in the sweep list, skip it + # to avoid running the exact same bs=1 experiment multiple times. + if current_bs != first_bs: + skipped += 1 + continue + + # Mutate the current combination to force batch_size to 1 + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + # Force the value to 1. If an alias was used, keep it clean. 
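+            # e.g. for a transductive dataset like graph/cocitation_cora the tuple
+            # ('bs', 'dataset.dataloader_params.batch_size', '256') becomes
+            # ('bs', 'dataset.dataloader_params.batch_size', '1'), so every kept
+            # combo runs with batch size 1 and its run name contains bs1.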
+ new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + for (tag, key, val) in combo: + if '::' in val: + alias, hydra_val_str = val.split('::', 1) + clean_val = alias + actual_val = hydra_val_str + else: + clean_val = os.path.basename(val) + actual_val = val + + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + if '@@@' in actual_val: + for part in actual_val.split('@@@'): + part = part.strip() + if part: + cmd_args.append(part) + else: + cmd_args.append(f'{key}={actual_val}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +# If IFS was polluted, read can split transforms=combined_pe; Hydra then errors on bare "combined_pe". +repair_hydra_transforms_arg() { + local -n _r=$1 + local out=() i + for ((i = 0; i < ${#_r[@]}; i++)); do + local t="${_r[i]}" + if [[ "$t" == transforms=* ]]; then + out+=("$t") + elif [[ "$t" == "transforms" && $((i + 1)) -lt ${#_r[@]} ]]; then + local nxt="${_r[$((i + 1))]}" + [[ "$nxt" == *"="* ]] && { out+=("$t"); continue; } + out+=("transforms=$nxt") + ((i++)) + elif [[ "$t" =~ ^(combined_pe|combined_fe|no_transform)$ ]]; then + out+=("transforms=$t") + else + out+=("$t") + fi + done + _r=("${out[@]}") +} + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." +echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + # Must not use inherited IFS (e.g. 
IFS== splits transforms=combined_pe → bare "combined_pe" for Hydra) + IFS=$' \t\n' read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + repair_hydra_transforms_arg DYNAMIC_ARGS_ARRAY + + # --- Extract dataset name for dynamic W&B project --- + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + ) + + # 6.6 Execute — printf %q so run_and_log's eval keeps key=value overrides as single words + # (broken IFS or nullglob in the parent shell can otherwise split transforms=combined_pe, etc.) + cmd_eval=$(printf '%q ' "${cmd[@]}") + run_and_log "${cmd_eval% }" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait +echo "✔ All runs complete." diff --git a/scripts/hopse_g.sh b/scripts/hopse_g.sh new file mode 100644 index 000000000..a0f92f9a9 --- /dev/null +++ b/scripts/hopse_g.sh @@ -0,0 +1,486 @@ +#!/bin/bash +# ============================================================================== +# SCRIPT: hopse_g_baselines.sh +# DESCRIPTION: +# Runs a scalable hyperparameter sweep for HOPSE_G models across both +# simplicial and cellular domains. +# - ARCHITECTURE: Uses a "Cartesian Product" generation strategy. +# - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU. +# - ORDERING: Prioritizes running all seeds for a config before moving on. +# - FILTERING: Skips invalid model+dataset combos (cell + simplicial data). +# ============================================================================== +# DO NOT MISS THIS + +export SELECTED_GPUS="2,3,4,5,6,7" +wandb_entity="gbg141-hopse" +RESUME=true # Set to true to skip already-completed runs (reads SUCCESSFUL_RUNS.log) + +# ============================================================================== +# SECTION 1: LOGGING & ENVIRONMENT SETUP +# ============================================================================== + +# Kill all background child processes if this script is interrupted or killed +trap 'echo -e "\n🛑 Interrupted! Cleaning up all background jobs..."; kill 0 2>/dev/null; exit 1' SIGINT SIGTERM + +# 1.1 Define Project Identifiers +script_name="$(basename "${BASH_SOURCE[0]}" .sh)" +project_name="${script_name}" +log_group="hopse_g_sweep" +LOG_DIR="./logs/${log_group}" + +echo "==========================================================" +echo " Preparing log directory: $LOG_DIR" +echo "==========================================================" + +# 1.2 Log directory management +if [[ "$RESUME" == "true" ]]; then + echo "⏩ RESUME MODE: Keeping existing logs." 
+ mkdir -p "$LOG_DIR" +else + if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi + mkdir -p "$LOG_DIR" +fi + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 + +# --- W&B Anti-Hang Protections --- +export WANDB_START_METHOD="thread" # Prevents multiprocessing deadlocks on exit +export WANDB__SERVICE_WAIT=300 # Forces daemon to timeout after 5 mins if stuck +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Output format: "JOBS_PER_GPU gpu_id_0 gpu_id_1 ..." +# Thresholds: >= 80 GB -> 4 jobs, <= 30 GB -> 2 jobs, between -> 3 jobs. + +_gpu_info=$(python3 -c " +import subprocess +import os + +# 1. Read the allowed GPUs from the environment variable +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + idx = idx.strip() + + # 2. 
Skip this GPU if it's not in our selected list + if allowed_gpus and idx not in allowed_gpus: + continue + + indices.append(idx) + mem_mb.append(int(mem.strip())) + + # Safety check in case the selected GPUs don't exist + if not indices: + print('0') + exit(0) + + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 4 + elif min_mem_gb <= 10: + jobs = 1 + elif min_mem_gb <= 30: + jobs = 2 + else: + jobs = 3 + + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Models (both domains) --- +# Use "alias::hydra_value" to disambiguate run names (both share basename "hopse_g"). +models=( + "cell_hopse_g::cell/hopse_g" + "sim_hopse_g::simplicial/hopse_g" +) +gpse_models=( + "molpcba" + "zinc" + # "pcqm4mv2" + # "geom" + # "chembl" +) + +# --- Datasets --- +datasets=( + # "graph/MUTAG" + # "graph/cocitation_cora" + # "graph/PROTEINS" + # "graph/NCI1" + # "graph/NCI109" + # "graph/cocitation_citeseer" + # "graph/cocitation_pubmed" + "graph/BBB_Martins" + "graph/Caco2_Wang" + "graph/Clearance_Hepatocyte_AZ" + "graph/CYP3A4_Veith" + "simplicial/mantra_name" + "simplicial/mantra_orientation" + "simplicial/mantra_betti_numbers" + "graph/ZINC" +) + +# --- Neighborhoods (8 configs from the original HOPSE study) --- +# Use "alias::hydra_value" format for readable run names. +neighborhoods=( + "adj1::[up_adjacency-0]" + "adj2::[up_adjacency-0,2-up_adjacency-0]" + "adj3::[up_adjacency-0,up_adjacency-1,2-up_adjacency-0,down_adjacency-1,down_adjacency-2,2-down_adjacency-2]" + "inc1::[up_incidence-0,2-up_incidence-0]" + "inc2::[up_incidence-0,up_incidence-1,2-up_incidence-0,down_incidence-1,down_incidence-2,2-down_incidence-2]" +) + +# --- Hyperparameters (superset across all dataset groups) --- +num_layers=(1 2 4) +hidden_channels=(128 256) +proj_dropouts=(0.25 0.5) +lrs=(0.01 0.001) +weight_decays=(0.0001) +batch_sizes=(128 256) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=10" + "delete_checkpoint_after_test=True" + "+combined_feature_encodings.preprocessor_device='cuda'" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING (CRITICAL ORDERING) +# Format: "ShortTag | HydraKey | ${Array[*]}" +# +# Values support an optional "alias::hydra_value" syntax for readable names. +# The generator also filters out invalid model+dataset combos. 
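+#
+# Illustrative expansion (derived from the Section 5 generator below; the values are
+# just one point of this sweep, shown for orientation):
+#   spec entry : "N|model.preprocessing_params.neighborhoods|adj1::[up_adjacency-0] ..."
+#   run name   : gets the token "Nadj1"
+#   Hydra arg  : model.preprocessing_params.neighborhoods=[up_adjacency-0]
+# Entries with an empty ShortTag (e.g. "|dataset|...") contribute only the basename of
+# the value to the run name (e.g. "BBB_Martins"), while the full value becomes the
+# Hydra override (dataset=graph/BBB_Martins).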
+# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + "N|model.preprocessing_params.neighborhoods|${neighborhoods[*]}" + "|transforms.hopse_encoding.pretrain_model|${gpse_models[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.n_layers|${num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Smart Transductive Filtering) +# ============================================================================== + +# Define where your dataset YAMLs live so the generator can inspect them. +# UPDATE THIS PATH IF YOUR CONFIGS ARE STORED ELSEWHERE. +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size in the sweep so we don't duplicate transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + model_val = vals_dict.get('model', '') + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Rule A: Skip cell model + simplicial dataset --- + if model_val.startswith('cell/') and dataset_val.startswith('simplicial/'): + skipped += 1 + continue + + # --- Rule B: Transductive Batch Size Handler --- + is_transductive = False + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + # Construct path to yaml (e.g., ./configs/dataset/graph/cocitation_cora.yaml) + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + # Fast text check avoids needing pip install pyyaml + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'⚠️ WARNING: Could not find config at {yaml_path}', file=sys.stderr) + + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + # If this isn't the first batch size in the sweep list, skip it + # to avoid running the exact same bs=1 experiment multiple times. 
+ if current_bs != first_bs: + skipped += 1 + continue + + # Mutate the current combination to force batch_size to 1 + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + # Force the value to 1. If an alias was used, keep it clean. + new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + for (tag, key, val) in combo: + if '::' in val: + alias, hydra_val_str = val.split('::', 1) + clean_val = alias + actual_val = hydra_val_str + else: + clean_val = os.path.basename(val) + actual_val = val + + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + cmd_args.append(f'{key}={actual_val}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + +# ============================================================================== +# SECTION 5.5: RESUME — LOAD COMPLETED RUNS +# ============================================================================== + +declare -A _completed_runs +if [[ "$RESUME" == "true" ]]; then + # run_and_log nests: $LOG_DIR/$log_group/SUCCESSFUL_RUNS.log + _success_log="$LOG_DIR/$log_group/SUCCESSFUL_RUNS.log" + if [[ -f "$_success_log" ]]; then + while IFS= read -r _line; do + # Format: "DATE: [SUCCESS] run_name" + _rname="${_line##*\[SUCCESS\] }" + _completed_runs["$_rname"]=1 + done < "$_success_log" + echo "✔ Loaded ${#_completed_runs[@]} completed runs to skip." + else + echo "⚠️ No SUCCESSFUL_RUNS.log found at $_success_log — nothing to skip." + fi +fi + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." +echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +skipped_completed=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.2.1 Skip if already completed (RESUME mode) + if [[ "$RESUME" == "true" && -n "${_completed_runs[$run_name]+x}" ]]; then + ((skipped_completed++)) + continue + fi + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! 
kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + + # --- Extract dataset name for dynamic W&B project --- + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + "+logger.wandb.name=${run_name}" + ) + + # 6.6 Execute + run_and_log "${cmd[*]}" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total, $skipped_completed skipped as already completed)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait +echo "✔ All runs complete." diff --git a/scripts/hopse_m.sh b/scripts/hopse_m.sh index 4aaef0e21..6c609c38f 100755 --- a/scripts/hopse_m.sh +++ b/scripts/hopse_m.sh @@ -9,6 +9,11 @@ # - ORDERING: Prioritizes running all seeds for a config before moving on. # - FILTERING: Skips invalid model+dataset combos (cell + simplicial data). # ============================================================================== +# DO NOT MISS THIS + +export SELECTED_GPUS="0,1,2,3,4,5,6,7" +wandb_entity="gbg141-hopse" +RESUME=true # Set to true to skip already-completed runs (reads SUCCESSFUL_RUNS.log) # ============================================================================== # SECTION 1: LOGGING & ENVIRONMENT SETUP @@ -24,9 +29,14 @@ echo "==========================================================" echo " Preparing log directory: $LOG_DIR" echo "==========================================================" -# 1.2 Clean up old logs to ensure a fresh run -if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi -mkdir -p "$LOG_DIR" +# 1.2 Log directory management +if [[ "$RESUME" == "true" ]]; then + echo "⏩ RESUME MODE: Keeping existing logs." + mkdir -p "$LOG_DIR" +else + if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi + mkdir -p "$LOG_DIR" +fi # 1.3 Robust Dependency Loading SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" @@ -51,7 +61,14 @@ else exit 1 fi - +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 # ============================================================================== # SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) # ============================================================================== @@ -61,6 +78,12 @@ fi # Thresholds: >= 80 GB -> 4 jobs, <= 30 GB -> 2 jobs, between -> 3 jobs. 
_gpu_info=$(python3 -c " import subprocess +import os + +# 1. Read the allowed GPUs from the environment variable +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + try: out = subprocess.check_output( ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], @@ -69,15 +92,30 @@ try: indices, mem_mb = [], [] for line in out.strip().splitlines(): idx, mem = line.split(',') - indices.append(idx.strip()) + idx = idx.strip() + + # 2. Skip this GPU if it's not in our selected list + if allowed_gpus and idx not in allowed_gpus: + continue + + indices.append(idx) mem_mb.append(int(mem.strip())) + + # Safety check in case the selected GPUs don't exist + if not indices: + print('0') + exit(0) + min_mem_gb = min(mem_mb) / 1024 if min_mem_gb >= 80: - jobs = 4 + jobs = 5 + elif min_mem_gb <= 10: + jobs = 1 elif min_mem_gb <= 30: jobs = 2 else: jobs = 3 + print(jobs, ' '.join(indices)) except Exception: print('2 0') @@ -113,14 +151,14 @@ models=( # --- Datasets --- datasets=( - "graph/MUTAG" - "graph/cocitation_cora" - "graph/PROTEINS" - "graph/NCI1" - "graph/NCI109" - "graph/ZINC" - "graph/cocitation_citeseer" - "graph/cocitation_pubmed" + # "graph/MUTAG" + # "graph/cocitation_cora" + # "graph/PROTEINS" + # "graph/NCI1" + # "graph/NCI109" + # "graph/ZINC" + # "graph/cocitation_citeseer" + # "graph/cocitation_pubmed" "simplicial/mantra_name" "simplicial/mantra_orientation" "simplicial/mantra_betti_numbers" @@ -147,7 +185,7 @@ num_layers=(1 2 4) hidden_channels=(128 256) proj_dropouts=(0.25 0.5) lrs=(0.01 0.001) -weight_decays=(0 0.0001) +weight_decays=(0.0001) batch_sizes=(128 256) DATA_SEEDS=(0 3 5 7 9) @@ -157,6 +195,8 @@ FIXED_ARGS=( "trainer.min_epochs=50" "trainer.check_val_every_n_epoch=5" "callbacks.early_stopping.patience=10" + "delete_checkpoint_after_test=True" + "+combined_feature_encodings.preprocessor_device='cuda'" ) @@ -305,6 +345,26 @@ for combo in valid: " "${SWEEP_CONFIG[@]}" } +# ============================================================================== +# SECTION 5.5: RESUME — LOAD COMPLETED RUNS +# ============================================================================== + +declare -A _completed_runs +if [[ "$RESUME" == "true" ]]; then + # run_and_log nests: $LOG_DIR/$log_group/SUCCESSFUL_RUNS.log + _success_log="$LOG_DIR/$log_group/SUCCESSFUL_RUNS.log" + if [[ -f "$_success_log" ]]; then + while IFS= read -r _line; do + # Format: "DATE: [SUCCESS] run_name" + _rname="${_line##*\[SUCCESS\] }" + _completed_runs["$_rname"]=1 + done < "$_success_log" + echo "✔ Loaded ${#_completed_runs[@]} completed runs to skip." + else + echo "⚠️ No SUCCESSFUL_RUNS.log found at $_success_log — nothing to skip." 
+ fi +fi + # ============================================================================== # SECTION 6: MAIN EXECUTION LOOP # ============================================================================== @@ -315,6 +375,7 @@ echo "----------------------------------------------------------" total_runs=0 run_counter=0 +skipped_completed=0 one_percent_step=1 while IFS=";" read -r col1 col2; do @@ -337,6 +398,12 @@ while IFS=";" read -r col1 col2; do run_name="$col1" dynamic_args_str="$col2" + # 6.2.1 Skip if already completed (RESUME mode) + if [[ "$RESUME" == "true" && -n "${_completed_runs[$run_name]+x}" ]]; then + ((skipped_completed++)) + continue + fi + # 6.3 Update Progress ((run_counter++)) if (( run_counter % one_percent_step == 0 )); then @@ -383,7 +450,9 @@ while IFS=";" read -r col1 col2; do "${DYNAMIC_ARGS_ARRAY[@]}" "${FIXED_ARGS[@]}" "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" "logger.wandb.project=${dynamic_project_name}" + "+logger.wandb.name=${run_name}" ) # 6.6 Execute @@ -397,7 +466,7 @@ done < <(generate_combinations) # SECTION 7: CLEANUP # ============================================================================== echo "----------------------------------------------------------" -echo " All jobs launched ($run_counter total)." +echo " All jobs launched ($run_counter total, $skipped_completed skipped as already completed)." echo " Waiting for remaining background jobs to finish..." echo "----------------------------------------------------------" wait diff --git a/scripts/hopse_m_ablation.sh b/scripts/hopse_m_ablation.sh new file mode 100755 index 000000000..944ff0f72 --- /dev/null +++ b/scripts/hopse_m_ablation.sh @@ -0,0 +1,484 @@ +#!/bin/bash +# ============================================================================== +# SCRIPT: hopse_m_ablation.sh +# DESCRIPTION: +# Runs a scalable hyperparameter sweep for HOPSE_M models across both +# simplicial and cellular domains. +# - ARCHITECTURE: Uses a "Cartesian Product" generation strategy. +# - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU. +# - ORDERING: Prioritizes running all seeds for a config before moving on. +# - FILTERING: Skips invalid model+dataset combos (cell + simplicial data). +# ============================================================================== +# DO NOT MISS THIS + +export SELECTED_GPUS="0,1,2,3,4,5,6,7" +wandb_entity="gbg141-hopse" +RESUME=true # Set to true to skip already-completed runs (reads SUCCESSFUL_RUNS.log) + +# ============================================================================== +# SECTION 1: LOGGING & ENVIRONMENT SETUP +# ============================================================================== + +# 1.1 Define Project Identifiers +script_name="$(basename "${BASH_SOURCE[0]}" .sh)" +project_name="${script_name}" +log_group="hopse_m_ablation_sweep" +LOG_DIR="./logs/${log_group}" + +echo "==========================================================" +echo " Preparing log directory: $LOG_DIR" +echo "==========================================================" + +# 1.2 Log directory management +if [[ "$RESUME" == "true" ]]; then + echo "⏩ RESUME MODE: Keeping existing logs." 
+ mkdir -p "$LOG_DIR" +else + if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi + mkdir -p "$LOG_DIR" +fi + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Output format: "JOBS_PER_GPU gpu_id_0 gpu_id_1 ..." +# Thresholds: >= 80 GB -> 4 jobs, <= 30 GB -> 2 jobs, between -> 3 jobs. +_gpu_info=$(python3 -c " +import subprocess +import os + +# 1. Read the allowed GPUs from the environment variable +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + idx = idx.strip() + + # 2. Skip this GPU if it's not in our selected list + if allowed_gpus and idx not in allowed_gpus: + continue + + indices.append(idx) + mem_mb.append(int(mem.strip())) + + # Safety check in case the selected GPUs don't exist + if not indices: + print('0') + exit(0) + + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 5 + elif min_mem_gb <= 10: + jobs = 1 + elif min_mem_gb <= 30: + jobs = 4 + else: + jobs = 3 + + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Models (both domains) --- +# Use "alias::hydra_value" to disambiguate run names (both share basename "hopse_m"). 
+models=( + "sim_hopse_m::simplicial/hopse_m" + # "cell_hopse_m::cell/hopse_m" +) + +# --- Datasets --- +datasets=( + "graph/MUTAG" + "graph/PROTEINS" + "graph/NCI1" + "graph/NCI109" + "graph/BBB_Martins" + "graph/Caco2_Wang" + "graph/Clearance_Hepatocyte_AZ" + "graph/CYP3A4_Veith" + "simplicial/mantra_name" + "simplicial/mantra_orientation" + "simplicial/mantra_betti_numbers" + "graph/cocitation_cora" + "graph/cocitation_citeseer" + "graph/cocitation_pubmed" + "graph/ZINC" +) + +# --- Neighborhoods (8 configs from the original HOPSE study) --- +# Use "alias::hydra_value" format for readable run names. +neighborhoods=( + "adj1::[up_adjacency-0]" + # "adj2::[up_adjacency-0,2-up_adjacency-0]" + "adj3::[up_adjacency-0,up_adjacency-1,2-up_adjacency-0,down_adjacency-1,down_adjacency-2,2-down_adjacency-2]" + # "inc1::[up_incidence-0,2-up_incidence-0]" + "inc2::[up_incidence-0,up_incidence-1,2-up_incidence-0,down_incidence-1,down_incidence-2,2-down_incidence-2]" +) + +# --- Encodings (two families to compare) --- +encodings=( + "lappe::[LapPE]" + "rwse::[RWSE]" + "electro::[ElectrostaticPE]" + "hkdiag::[HKdiagSE]" + "hkfe::[HKFE]" + "khopfe::[KHopFE]" + "pprfe::[PPRFE]" + # "pse::[LapPE,RWSE,ElectrostaticPE,HKdiagSE]" + # "fe::[HKFE,KHopFE,PPRFE]" +) + +# --- Hyperparameters (superset across all dataset groups) --- +num_layers=(1 2) +hidden_channels=(128) +proj_dropouts=(0.25) +lrs=(0.001) +weight_decays=(0.0001) +batch_sizes=(128) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=20" + "delete_checkpoint_after_test=True" + "+combined_feature_encodings.preprocessor_device='cuda'" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING (CRITICAL ORDERING) +# Format: "ShortTag | HydraKey | ${Array[*]}" +# +# Values support an optional "alias::hydra_value" syntax for readable names. +# The generator also filters out invalid model+dataset combos. +# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + "N|model.preprocessing_params.neighborhoods|${neighborhoods[*]}" + "enc|model.preprocessing_params.encodings|${encodings[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.n_layers|${num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Smart Transductive Filtering) +# ============================================================================== + +# Define where your dataset YAMLs live so the generator can inspect them. +# UPDATE THIS PATH IF YOUR CONFIGS ARE STORED ELSEWHERE. +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. 
Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size in the sweep so we don't duplicate transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + model_val = vals_dict.get('model', '') + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Rule A: Skip cell model + simplicial dataset --- + if model_val.startswith('cell/') and dataset_val.startswith('simplicial/'): + skipped += 1 + continue + + # --- Rule B: Transductive Batch Size Handler --- + is_transductive = False + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + # Construct path to yaml (e.g., ./configs/dataset/graph/cocitation_cora.yaml) + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + # Fast text check avoids needing pip install pyyaml + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'⚠️ WARNING: Could not find config at {yaml_path}', file=sys.stderr) + + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + # If this isn't the first batch size in the sweep list, skip it + # to avoid running the exact same bs=1 experiment multiple times. + if current_bs != first_bs: + skipped += 1 + continue + + # Mutate the current combination to force batch_size to 1 + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + # Force the value to 1. If an alias was used, keep it clean. + new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. 
Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + for (tag, key, val) in combo: + if '::' in val: + alias, hydra_val_str = val.split('::', 1) + clean_val = alias + actual_val = hydra_val_str + else: + clean_val = os.path.basename(val) + actual_val = val + + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + cmd_args.append(f'{key}={actual_val}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + +# ============================================================================== +# SECTION 5.5: RESUME — LOAD COMPLETED RUNS +# ============================================================================== + +declare -A _completed_runs +if [[ "$RESUME" == "true" ]]; then + # run_and_log nests: $LOG_DIR/$log_group/SUCCESSFUL_RUNS.log + _success_log="$LOG_DIR/$log_group/SUCCESSFUL_RUNS.log" + if [[ -f "$_success_log" ]]; then + while IFS= read -r _line; do + # Format: "DATE: [SUCCESS] run_name" + _rname="${_line##*\[SUCCESS\] }" + _completed_runs["$_rname"]=1 + done < "$_success_log" + echo "✔ Loaded ${#_completed_runs[@]} completed runs to skip." + else + echo "⚠️ No SUCCESSFUL_RUNS.log found at $_success_log — nothing to skip." + fi +fi + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." +echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +skipped_completed=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.2.1 Skip if already completed (RESUME mode) + if [[ "$RESUME" == "true" && -n "${_completed_runs[$run_name]+x}" ]]; then + ((skipped_completed++)) + continue + fi + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! 
kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + + # --- Extract dataset name for dynamic W&B project --- + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + "+logger.wandb.name=${run_name}" + ) + + # 6.6 Execute + run_and_log "${cmd[*]}" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total, $skipped_completed skipped as already completed)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait +echo "✔ All runs complete." diff --git a/scripts/hopse_plotting/aggregator.py b/scripts/hopse_plotting/aggregator.py new file mode 100644 index 000000000..467f32759 --- /dev/null +++ b/scripts/hopse_plotting/aggregator.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" +Aggregate per-run W&B export rows across ``dataset.split_params.data_seed``. + +Reads per-run export CSV(s)—by default, every ``*.csv`` under +``csvs/hopse_experiments_wandb_export_shards`` when that folder exists and has +files; otherwise the monolithic ``csvs/hopse_experiments_wandb_export.csv``. +Several shard files are aggregated then concatenated into one ``-o`` CSV +(default: ``csvs/hopse_experiments_wandb_export_seed_agg.csv``). + +By default only hyperparameter groups with exactly ``--required-seeds`` raw +runs (after grouping on everything except the data seed) are written to the +output CSV; see the printed per-(model, dataset) distribution for other counts. + +Usage:: + + python scripts/hopse_plotting/aggregator.py + python scripts/hopse_plotting/aggregator.py -i path/to/export.csv -o path/to/agg.csv + python scripts/hopse_plotting/aggregator.py --input-dir scripts/hopse_plotting/csvs/hopse_experiments_wandb_export_shards + python scripts/hopse_plotting/aggregator.py --keep-incomplete-seeds + python scripts/hopse_plotting/aggregator.py --plot-seed-distributions + python scripts/hopse_plotting/aggregator.py --plot-seed-distributions --seed-dist-dir plots/seed_n +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import pandas as pd + +from utils import ( + DEFAULT_AGGREGATED_EXPORT_CSV, + DEFAULT_WANDB_EXPORT_CSV, + DEFAULT_WANDB_EXPORT_SHARD_DIR, + PLOTS_DIR, + aggregate_many_wandb_export_csvs, + aggregate_wandb_export_csv, +) + +# Exact raw-run count per hyperparameter group required for a row to appear in ``-o``. 
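+# The sweep scripts in this repo (hopse_g.sh / hopse_m.sh / hopse_m_ablation.sh) use
+# DATA_SEEDS=(0 3 5 7 9), so a "complete" hyperparameter group corresponds to 5 raw runs.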
+DEFAULT_REQUIRED_AGGREGATED_SEEDS = 5 + + +def _print_seed_bucket_report( + report, + *, + required_n_seeds: int | None, +) -> None: + if report.empty: + print("Seed-count distribution: (no aggregated hyperparameter groups).") + return + if required_n_seeds is not None: + print( + f"Seed-count distribution (hyperparameter groups per model+dataset); " + f"output CSV keeps only n_seeds=={required_n_seeds}." + ) + else: + print( + "Seed-count distribution (hyperparameter groups per model+dataset); " + "output CSV keeps all n_seeds (--keep-incomplete-seeds)." + ) + for (model, dataset), sub in report.groupby(["model", "dataset"], dropna=False): + print(f"\n model={model!r} dataset={dataset!r}") + sub_sorted = sub.sort_values("n_seeds") + for _, row in sub_sorted.iterrows(): + k = row["n_seeds"] + try: + k_int = int(k) if pd.notna(k) else k + except (TypeError, ValueError): + k_int = k + mark = ( + " <- rows written to -o" + if required_n_seeds is not None + and pd.notna(k) + and int(k) == int(required_n_seeds) + else "" + ) + print( + f" n_seeds={k_int}: {int(row['n_groups'])} groups " + f"({float(row['pct_of_groups']):.2f}% of groups for this pair){mark}" + ) + + +def _collect_input_paths( + *, + explicit: list[Path] | None, + input_dir: Path | None, + input_pattern: str, +) -> list[Path]: + paths: list[Path] = [] + if explicit: + paths.extend(explicit) + if input_dir is not None: + d = Path(input_dir) + if d.is_dir(): + paths.extend(sorted(d.glob(input_pattern))) + if not paths: + paths = [DEFAULT_WANDB_EXPORT_CSV] + seen: set[Path] = set() + uniq: list[Path] = [] + for p in paths: + rp = p.resolve() + if rp not in seen: + seen.add(rp) + uniq.append(p) + return uniq + + +def main() -> None: + p = argparse.ArgumentParser( + description="Aggregate W&B export CSV(s) over data seeds; always one combined -o CSV." + ) + p.add_argument( + "-i", + "--input", + action="append", + type=Path, + default=None, + metavar="PATH", + help=f"Per-run export CSV (repeat for multiple shards). If omitted, see --input-dir / default shard folder.", + ) + p.add_argument( + "--input-dir", + type=Path, + default=None, + help=( + "Aggregate every file matching --input-pattern under this directory. " + "If -i is not given and this is omitted, uses the shard folder when it " + f"contains CSVs, else {DEFAULT_WANDB_EXPORT_CSV}" + ), + ) + p.add_argument( + "--input-pattern", + default="*.csv", + help="Glob under --input-dir (default: *.csv)", + ) + p.add_argument( + "-o", + "--output", + type=Path, + default=DEFAULT_AGGREGATED_EXPORT_CSV, + help=f"Single combined seed-aggregated CSV (default: {DEFAULT_AGGREGATED_EXPORT_CSV})", + ) + p.add_argument( + "--required-seeds", + type=int, + default=DEFAULT_REQUIRED_AGGREGATED_SEEDS, + metavar="N", + help=( + "Only write hyperparameter groups aggregated from exactly this many " + f"raw runs (default: {DEFAULT_REQUIRED_AGGREGATED_SEEDS}). " + "Ignored with --keep-incomplete-seeds." + ), + ) + p.add_argument( + "--keep-incomplete-seeds", + action="store_true", + help=( + "Write all aggregated groups regardless of run count; still print the " + "per-(model,dataset) n_seeds distribution." + ), + ) + p.add_argument( + "--plot-seed-distributions", + action="store_true", + help=( + "After aggregating, write per-model PNGs (bar chart of n_seeds vs #groups per " + "dataset subplot) under --seed-dist-dir (see seed_n_distribution_plots.py)." 
+ ), + ) + p.add_argument( + "--seed-dist-dir", + type=Path, + default=None, + help=f"Output directory for --plot-seed-distributions (default: {PLOTS_DIR}/seed_n_distributions).", + ) + args = p.parse_args() + if not args.keep_incomplete_seeds and int(args.required_seeds) < 1: + p.error("--required-seeds must be >= 1 unless --keep-incomplete-seeds is set.") + + input_dir = args.input_dir + if args.input is None and input_dir is None: + sd = DEFAULT_WANDB_EXPORT_SHARD_DIR + if sd.is_dir() and any(sd.glob(args.input_pattern)): + input_dir = sd + + paths = _collect_input_paths( + explicit=args.input, + input_dir=input_dir, + input_pattern=args.input_pattern, + ) + args.output.parent.mkdir(parents=True, exist_ok=True) + req = None if args.keep_incomplete_seeds else int(args.required_seeds) + + if len(paths) == 1: + agg, report = aggregate_wandb_export_csv( + paths[0], args.output, required_n_seeds=req + ) + print(f"Wrote {len(agg)} aggregated rows x {len(agg.columns)} columns -> {args.output}") + else: + agg, report = aggregate_many_wandb_export_csvs(paths, args.output, required_n_seeds=req) + print( + f"Combined {len(paths)} shard file(s) -> {len(agg)} aggregated rows x {len(agg.columns)} columns -> {args.output}" + ) + + _print_seed_bucket_report(report, required_n_seeds=req) + + if args.plot_seed_distributions: + from seed_n_distribution_plots import write_seed_distribution_plots + + dist_dir = args.seed_dist_dir or (PLOTS_DIR / "seed_n_distributions") + nfig = write_seed_distribution_plots( + report, + dist_dir, + required_n_seeds=req, + dpi=150, + ) + print(f"Seed n_seeds bar plots: wrote {nfig} figure(s) -> {dist_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/hopse_plotting/best_rerun_sh_generator.py b/scripts/hopse_plotting/best_rerun_sh_generator.py new file mode 100644 index 000000000..a3446c75e --- /dev/null +++ b/scripts/hopse_plotting/best_rerun_sh_generator.py @@ -0,0 +1,628 @@ +#!/usr/bin/env python3 +""" +From a **seed-aggregated** W&B CSV, pick the best validation row per (model, dataset) +(same rule as ``collapse_aggregated_wandb_by_best_val`` / ``main_plot`` / ``table_generator``: +``utils.iter_best_val_group_picks``), then emit **two** bash scripts with the same Hydra commands: + +1. **Sequential** (default ``scripts/best_val_reruns_sequential.sh``): one ``python -m topobench`` + line after another (no GPU assignment; use ``--append-arg trainer.devices=[0]`` if needed). +2. **Parallel** (default ``scripts/best_val_reruns_parallel.sh``): same runs launched with + ``&``, ``trainer.devices=[GPU]`` round-robin over ``0..7`` by default (like + ``topotune/search_gccn_cell.sh``), then ``wait`` for all jobs. + +The aggregated export usually drops ``dataset.split_params.data_seed``; this script +appends ``dataset.split_params.data_seed=...`` for each seed in ``--data-seeds`` (default +the sweep set ``0,3,5,7,9``) so reruns match the original multi-seed protocol. + +Emitted ``.sh`` files use **LF** line endings only (``newline='\\n'``) so bash/WSL and Hydra +are not broken by Windows CRLF. + +Dataset overrides use Hydra config stems (e.g. ``graph/cocitation_cora``): loader +``data_name`` rows in the CSV (e.g. ``graph/Cora``) are rewritten using +``DATASET_LOADER_IDENTITY_TO_HYDRA`` in ``utils``. + +By default only **non-transductive** loader datasets are emitted: ``main_loader.DATASETS`` +minus ``graph/cocitation_{cora,citeseer,pubmed}``. Use ``--all-datasets`` to emit every +(model, dataset) group in the CSV. 
+ +**Training defaults** mirror the sweep scripts (``gat.sh`` / ``gcn.sh`` / ``hopse_m.sh`` / +``topotune.sh`` / ``sann.sh`` / ``sccnn.sh`` / ``cwn.sh``): ``trainer.min_epochs=50``, +``trainer.check_val_every_n_epoch=5``, plus model-specific extras (``delete_checkpoint_after_test``, +HOPSE preprocessor device, and early-stopping patience 5 vs 10) when ``--fixed-args-profile auto`` +(default). TopoTune, SANN, SCCNN, and CWN share the same non-HOPSE block: patience 10 and +``delete_checkpoint_after_test=True``. Use ``--fixed-args-profile none`` to omit those extras +(not recommended for matching sweeps). + +Every command includes ``trainer.max_epochs`` (default 500) and ``callbacks.early_stopping.patience`` +(either ``--early-stopping-patience INT`` or, if omitted, 5 for ``graph/*{gin,gat,gcn}`` and 10 +for HOPSE / TopoTune / SANN / SCCNN / CWN under ``auto``). + +**Data seeds:** seed-aggregated CSVs drop ``dataset.split_params.data_seed``. By default this +script emits **one command per sweep seed** (``--data-seeds 0,3,5,7,9``) for each best-val +(model, dataset) row. Use ``--data-seeds 0`` to match the old single-seed behavior. With +``--keep-row-seed``, a per-run export that still has the seed column emits that seed only. + +Every rerun also sets ``deterministic=True`` (see ``configs/run.yaml``) so reviewers can +reproduce runs; override last with ``--append-arg deterministic=False`` if needed. + +W&B logging matches ``hopse_m.sh`` style: ``+logger.wandb.entity``, ``logger.wandb.project`` (same +project for every line by default: ``best_runs_rerun``), and ``+logger.wandb.name`` derived from +model/dataset (and seed when multiple). Disable with ``--no-wandb-logger`` or override via +``--append-arg`` (appended last). + +Further extras: ``--append-arg`` (e.g. ``trainer.devices=[0]``; later args override earlier). +On the **parallel** script, ``--append-arg trainer.devices=...`` overrides the round-robin GPU +for that slot (still appended last). + +**Hydra overrides from the winner row** use ``utils.hydra_overrides_from_aggregated_row`` with +``utils.CONFIG_PARAM_KEYS`` (same column contract as ``main_loader`` / seed aggregation). That +includes sweep axes that must be present for correct reruns, for example: + +- **HOPSE_G / GPSE** — ``transforms.hopse_encoding.pretrain_model`` (and related + ``transforms.hopse_encoding.*`` keys) so molpcba vs zinc checkpoints are not dropped. +- **SANN** — ``transforms.sann_encoding.*``, ``model.feature_encoder.selected_dimensions``, etc. + +Only non-empty cells become ``key=value`` flags; re-export from W&B after extending +``CONFIG_PARAM_KEYS`` so the seed-aggregated CSV carries those columns. 
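+
+For orientation, a single best-val (model, dataset) winner expands into lines roughly like
+the following in the two emitted scripts, repeated once per ``--data-seeds`` value (the
+overrides and values shown here are illustrative placeholders, not taken from a real
+export; the exact flag set is the winner row's overrides plus the profile extras above)::
+
+    # best_val_reruns_sequential.sh
+    # simplicial/hopse_m | graph/ZINC | data_seed=0
+    python -m topobench model=simplicial/hopse_m dataset=graph/ZINC ... dataset.split_params.data_seed=0 trainer.max_epochs=500 deterministic=True ...
+
+    # best_val_reruns_parallel.sh (same overrides, round-robin GPU, background job)
+    _gpu="${GPUS[$((_i % _NUM_GPUS))]}"; _i=$((_i + 1))
+    python -m topobench model=simplicial/hopse_m dataset=graph/ZINC ... deterministic=True ... trainer.devices=[${_gpu}] &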
+ +Usage:: + + python scripts/hopse_plotting/best_rerun_sh_generator.py + python scripts/hopse_plotting/best_rerun_sh_generator.py -i scripts/hopse_plotting/csvs/hopse_experiments_wandb_export_seed_agg.csv \\ + -o scripts/best_val_reruns_sequential.sh \\ + --output-parallel scripts/best_val_reruns_parallel.sh + python scripts/hopse_plotting/best_rerun_sh_generator.py --parallel-gpus 0,1,2,3 + python scripts/hopse_plotting/best_rerun_sh_generator.py --data-seeds 0 + python scripts/hopse_plotting/best_rerun_sh_generator.py --no-parallel-script +""" + +from __future__ import annotations + +import argparse +import shlex +from pathlib import Path + +import pandas as pd + +from main_loader import DATASETS as LOADER_DATASETS +from utils import ( + CONFIG_PARAM_KEYS, + DEFAULT_AGGREGATED_EXPORT_CSV, + SEED_COLUMN, + aggregated_rows_best_validation_per_group, + hydra_dataset_key_from_loader_identity, + hydra_overrides_from_aggregated_row, + load_wandb_export_csv, + safe_filename_token, +) + +# Repo ``scripts/`` (parent of ``hopse_plotting/``) +_DEFAULT_SCRIPTS_DIR = Path(__file__).resolve().parent.parent +DEFAULT_EMIT_SH_SEQUENTIAL = _DEFAULT_SCRIPTS_DIR / "best_val_reruns_sequential.sh" +DEFAULT_EMIT_SH_PARALLEL = _DEFAULT_SCRIPTS_DIR / "best_val_reruns_parallel.sh" + +# Match sweep scripts ``trainer.max_epochs=500``. +DEFAULT_MAX_EPOCHS = 500 + +# Same order as ``DATA_SEEDS`` in ``gat.sh`` / ``hopse_m.sh`` / ``topotune.sh``. +DEFAULT_SWEEP_DATA_SEEDS = "0,3,5,7,9" + +# Match scripts/hopse_m.sh wandb_entity= / logger.wandb.project (single project for all reruns). +DEFAULT_WANDB_ENTITY = "gbg141-hopse" +DEFAULT_WANDB_PROJECT = "best_runs_rerun" + +# Same coverage as ``main_loader.DATASETS`` but drop Planetoid cocitation (transductive) configs. +_TRANSDUCTIVE_COCITATION_HYDRA: frozenset[str] = frozenset( + { + "graph/cocitation_cora", + "graph/cocitation_citeseer", + "graph/cocitation_pubmed", + } +) +DEFAULT_RERUN_ALLOWED_HYDRA_DATASETS: frozenset[str] = frozenset( + d for d in LOADER_DATASETS if d not in _TRANSDUCTIVE_COCITATION_HYDRA +) + +DEFAULT_PARALLEL_GPUS = "0,1,2,3,4,5,6,7" + + +def _sort_key_model_dataset(row) -> tuple[str, str]: + m = str(row.get("model", "")) + d = str(row.get("dataset", "")) + return (m, d) + + +def dataframe_filter_rerun_datasets( + df: pd.DataFrame, + *, + allowed_hydra: frozenset[str], +) -> pd.DataFrame: + """ + Keep rows whose ``dataset`` cell maps (via ``hydra_dataset_key_from_loader_identity``) + into ``allowed_hydra`` (e.g. loader list without cocitation cora/citeseer/pubmed). 
+ """ + if "dataset" not in df.columns: + raise KeyError("CSV missing 'dataset' column") + + def canon(ds_val: object) -> str: + return hydra_dataset_key_from_loader_identity(str(ds_val).replace("\r", "").strip()) + + mask = df["dataset"].map(lambda v: canon(v) in allowed_hydra) + return df.loc[mask].copy() + + +def _parse_parallel_gpus(s: str) -> list[int]: + out: list[int] = [] + for part in str(s).replace("\r", "").split(","): + p = part.strip() + if p: + out.append(int(p)) + return out if out else [0] + + +def _parse_data_seeds(s: str) -> list[str]: + """Comma-separated ints -> string tokens for Hydra (``3`` not ``3.0``).""" + out: list[str] = [] + for part in str(s).replace("\r", "").split(","): + p = part.strip() + if not p: + continue + x = float(p) + out.append(str(int(x)) if x.is_integer() else p) + return out if out else ["0"] + + +def _resolve_fixed_args_profile(model: str, profile: str) -> str: + p = str(profile).replace("\r", "").strip().lower() + if p == "auto": + m = str(model).lower().replace("\r", "").strip() + if "hopse" in m: + return "hopse" + if "topotune" in m: + return "topotune" + # ``simplicial/sann``, ``cell/sann``, ``simplicial/sann_online``, … + if m.split("/")[-1].startswith("sann"): + return "sann" + # ``scripts/sccnn.sh`` — same FIXED_ARGS as TopoTune / ``scripts/cwn.sh``. + if m.split("/")[-1].startswith("sccnn"): + return "sccnn" + if m.split("/")[-1] == "cwn": + return "cwn" + if m.startswith("graph/gin") or m.startswith("graph/gat") or m.startswith("graph/gcn"): + return "graph" + if m.startswith("graph/"): + return "graph" + # Other models: TopoTune-style extras only — never HOPSE-only CUDA preprocessor. + return "topotune" + return p + + +def _benchmark_training_extras(profile: str) -> list[str]: + """Pieces of ``FIXED_ARGS`` from sweep scripts (excluding max_epochs / early stopping).""" + if profile in ("", "none"): + return [] + out = [ + "trainer.min_epochs=50", + "trainer.check_val_every_n_epoch=5", + ] + if profile == "hopse": + out.extend( + [ + "delete_checkpoint_after_test=True", + "+combined_feature_encodings.preprocessor_device='cuda'", + ] + ) + elif profile in ("topotune", "sann", "sccnn", "cwn"): + out.append("delete_checkpoint_after_test=True") + return out + + +def _default_early_stopping_patience(profile: str) -> int: + return 5 if profile == "graph" else 10 + + +def _row_data_seeds( + row, + *, + keep_row_seed: bool, + default_seeds: list[str], +) -> list[str]: + if not keep_row_seed: + return list(default_seeds) + if SEED_COLUMN not in row: + return list(default_seeds) + raw = row[SEED_COLUMN] + if pd.isna(raw): + return list(default_seeds) + s = str(raw).replace("\r", "").strip() + if s == "" or s.lower() in {"nan", "none"}: + return list(default_seeds) + x = float(s) + return [str(int(x)) if x.is_integer() else s] + + +def _base_hydra_parts_for_row( + row, + *, + skip_seed: set[str], + data_seed: str, + max_epochs: int, + early_stopping_patience: int | None, + fixed_args_profile: str, + wandb_entity: str | None, + wandb_project: str | None, + wandb_run_name: bool, + wandb_run_suffix: str, +) -> tuple[str, str, list[str]]: + """Hydra overrides for one winner row (no ``--append-arg`` extras, no ``trainer.devices``).""" + model = str(row.get("model", "")).replace("\r", "").strip() + dataset_raw = str(row.get("dataset", "")).replace("\r", "").strip() + dataset = hydra_dataset_key_from_loader_identity(dataset_raw) + resolved_profile = _resolve_fixed_args_profile(model, fixed_args_profile) + + parts = hydra_overrides_from_aggregated_row( + 
row, + config_keys=list(CONFIG_PARAM_KEYS), + skip_keys=skip_seed, + ) + parts.extend(_benchmark_training_extras(resolved_profile)) + if not any(p.startswith(f"{SEED_COLUMN}=") for p in parts): + parts.append(f"{SEED_COLUMN}={data_seed}") + parts.append(f"trainer.max_epochs={max_epochs}") + es = ( + early_stopping_patience + if early_stopping_patience is not None + else _default_early_stopping_patience(resolved_profile) + ) + parts.append(f"callbacks.early_stopping.patience={int(es)}") + parts.append("deterministic=True") + if wandb_entity and wandb_project: + parts.append(f"+logger.wandb.entity={wandb_entity}") + parts.append(f"logger.wandb.project={wandb_project}") + if wandb_run_name: + base_nm = f"{model.replace('/', '__')}__{dataset.replace('/', '__')}" + if wandb_run_suffix: + base_nm = f"{base_nm}{wandb_run_suffix}" + wname = safe_filename_token(base_nm, max_len=120) + parts.append(f"+logger.wandb.name={wname}") + return model, dataset, parts + + +def _sorted_winner_rows(df, *, group_cols: list[str]): + winners = aggregated_rows_best_validation_per_group(df, group_cols=group_cols) + if winners.empty: + raise ValueError("No rows after best-val selection (empty input?)") + rows = [winners.iloc[i] for i in range(len(winners))] + rows.sort(key=_sort_key_model_dataset) + return rows + + +def emit_sequential_rerun_script( + df, + *, + path: Path, + interpreter: str, + data_seeds: list[str], + append_args: list[str], + keep_row_seed: bool, + group_cols: list[str], + max_epochs: int, + early_stopping_patience: int | None, + fixed_args_profile: str, + wandb_entity: str | None, + wandb_project: str | None, + wandb_run_name: bool, +) -> int: + skip_seed = set() if keep_row_seed else {SEED_COLUMN} + rows = _sorted_winner_rows(df, group_cols=group_cols) + app = [a.replace("\r", "") for a in append_args] + + lines: list[str] = [ + "#!/usr/bin/env bash", + "# Auto-generated: best val per (model, dataset) — run commands one after another.", + "# Pair script: best_val_reruns_parallel.sh (GPUs in parallel, then wait).", + "", + ] + + n_cmd = 0 + for row in rows: + seeds = _row_data_seeds(row, keep_row_seed=keep_row_seed, default_seeds=data_seeds) + multi = len(seeds) > 1 + for data_seed in seeds: + suffix = f"__ds{data_seed}" if multi and wandb_run_name else "" + model, dataset, base = _base_hydra_parts_for_row( + row, + skip_seed=skip_seed, + data_seed=data_seed, + max_epochs=max_epochs, + early_stopping_patience=early_stopping_patience, + fixed_args_profile=fixed_args_profile, + wandb_entity=wandb_entity, + wandb_project=wandb_project, + wandb_run_name=wandb_run_name, + wandb_run_suffix=suffix, + ) + parts = list(base) + parts.extend(app) + cmd = shlex.join([interpreter, "-m", "topobench", *parts]) + seed_note = f" | data_seed={data_seed}" if multi else "" + lines.append(f"# {model} | {dataset}{seed_note}") + lines.append(cmd) + lines.append("") + n_cmd += 1 + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + text = "\n".join(lines).rstrip() + "\n" + with path.open("w", encoding="utf-8", newline="\n") as f: + f.write(text) + try: + path.chmod(path.stat().st_mode | 0o111) + except OSError: + pass + return n_cmd + + +def emit_parallel_rerun_script( + df, + *, + path: Path, + interpreter: str, + data_seeds: list[str], + append_args: list[str], + keep_row_seed: bool, + group_cols: list[str], + max_epochs: int, + early_stopping_patience: int | None, + fixed_args_profile: str, + wandb_entity: str | None, + wandb_project: str | None, + wandb_run_name: bool, + gpu_ids: list[int], +) -> 
int: + skip_seed = set() if keep_row_seed else {SEED_COLUMN} + rows = _sorted_winner_rows(df, group_cols=group_cols) + app = [a.replace("\r", "") for a in append_args] + + gpu_bash_array = " ".join(str(g) for g in gpu_ids) + lines: list[str] = [ + "#!/usr/bin/env bash", + "# Auto-generated: same best-val reruns as best_val_reruns_sequential.sh, but launch in parallel.", + "# trainer.devices=[GPU] round-robins over GPUS; each job runs in background; wait at end.", + "", + "# Optional: match hopse_m.sh thread limits when many jobs share a machine", + "# export OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1", + "", + f"GPUS=({gpu_bash_array})", + "_NUM_GPUS=${#GPUS[@]}", + "_i=0", + "", + ] + + n_cmd = 0 + for row in rows: + seeds = _row_data_seeds(row, keep_row_seed=keep_row_seed, default_seeds=data_seeds) + multi = len(seeds) > 1 + for data_seed in seeds: + suffix = f"__ds{data_seed}" if multi and wandb_run_name else "" + model, dataset, base = _base_hydra_parts_for_row( + row, + skip_seed=skip_seed, + data_seed=data_seed, + max_epochs=max_epochs, + early_stopping_patience=early_stopping_patience, + fixed_args_profile=fixed_args_profile, + wandb_entity=wandb_entity, + wandb_project=wandb_project, + wandb_run_name=wandb_run_name, + wandb_run_suffix=suffix, + ) + pre = shlex.join([interpreter, "-m", "topobench", *base]) + post = shlex.join(app) if app else "" + # Bash sets _gpu then Hydra sees trainer.devices=[0] style (variable expands inside [...]). + dev_fragment = r"trainer.devices=[${_gpu}]" + if post: + cmd_body = f"{pre} {dev_fragment} {post}" + else: + cmd_body = f"{pre} {dev_fragment}" + seed_note = f" | data_seed={data_seed}" if multi else "" + lines.append(f"# {model} | {dataset}{seed_note}") + lines.append('_gpu="${GPUS[$((_i % _NUM_GPUS))]}"; _i=$((_i + 1))') + lines.append(f"{cmd_body} &") + lines.append("") + n_cmd += 1 + + lines.append("wait") + lines.append('echo "All parallel reruns finished."') + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + text = "\n".join(lines).rstrip() + "\n" + with path.open("w", encoding="utf-8", newline="\n") as f: + f.write(text) + try: + path.chmod(path.stat().st_mode | 0o111) + except OSError: + pass + return n_cmd + + +def main() -> None: + p = argparse.ArgumentParser( + description="Emit sequential + parallel bash scripts for best-val topobench reruns." 
+ ) + p.add_argument( + "-i", + "--input", + type=Path, + default=DEFAULT_AGGREGATED_EXPORT_CSV, + help=f"Seed-aggregated CSV (default: {DEFAULT_AGGREGATED_EXPORT_CSV})", + ) + p.add_argument( + "-o", + "--output", + type=Path, + default=DEFAULT_EMIT_SH_SEQUENTIAL, + help=f"Sequential .sh path (default: {DEFAULT_EMIT_SH_SEQUENTIAL})", + ) + p.add_argument( + "--output-parallel", + type=Path, + default=DEFAULT_EMIT_SH_PARALLEL, + help=f"Parallel .sh path (default: {DEFAULT_EMIT_SH_PARALLEL})", + ) + p.add_argument( + "--no-parallel-script", + action="store_true", + help="Only write the sequential script.", + ) + p.add_argument( + "--parallel-gpus", + default=DEFAULT_PARALLEL_GPUS, + help=f"Comma-separated GPU indices for round-robin trainer.devices (default: {DEFAULT_PARALLEL_GPUS})", + ) + p.add_argument( + "--group-by", + metavar="COL", + nargs="+", + default=["model", "dataset"], + help="Group columns for best-val pick (default: model dataset)", + ) + p.add_argument( + "--interpreter", + default="python", + help="Python executable (default: python)", + ) + p.add_argument( + "--data-seeds", + default=DEFAULT_SWEEP_DATA_SEEDS, + help=( + "Comma-separated dataset.split_params.data_seed values; emits one command " + f"per best-val row per seed (default: {DEFAULT_SWEEP_DATA_SEEDS})" + ), + ) + p.add_argument( + "--max-epochs", + type=int, + default=DEFAULT_MAX_EPOCHS, + help=f"trainer.max_epochs=... for every command (default: {DEFAULT_MAX_EPOCHS})", + ) + p.add_argument( + "--early-stopping-patience", + type=int, + default=None, + metavar="N", + help=( + "callbacks.early_stopping.patience=...; if omitted, uses 5 for graph " + "gin/gat/gcn and 10 for HOPSE / TopoTune / SANN / SCCNN / CWN under the resolved " + "--fixed-args-profile (see --fixed-args-profile)." + ), + ) + p.add_argument( + "--fixed-args-profile", + choices=("auto", "graph", "hopse", "topotune", "sann", "sccnn", "cwn", "none"), + default="auto", + help=( + "Sweep-style extras after row overrides: min_epochs, check_val_every_n_epoch, " + "and model-specific flags (HOPSE: delete_checkpoint + preprocessor_device; " + "TopoTune / SANN / SCCNN / CWN: delete_checkpoint only). ``auto`` picks from the " + "model path (default: auto)." + ), + ) + p.add_argument( + "--append-arg", + action="append", + default=[], + metavar="KEY=VALUE", + help="Extra Hydra override appended after trainer/ES args (repeatable; overrides if key repeats)", + ) + p.add_argument( + "--keep-row-seed", + action="store_true", + help=( + "If the CSV still has dataset.split_params.data_seed (per-run export), emit only " + "that seed per row; otherwise use --data-seeds." + ), + ) + p.add_argument( + "--wandb-entity", + default=DEFAULT_WANDB_ENTITY, + help=f"W&B entity for every command (default: {DEFAULT_WANDB_ENTITY})", + ) + p.add_argument( + "--wandb-project", + default=DEFAULT_WANDB_PROJECT, + help=f"W&B project for every command (default: {DEFAULT_WANDB_PROJECT})", + ) + p.add_argument( + "--no-wandb-logger", + action="store_true", + help="Do not append logger.wandb entity/project/name overrides.", + ) + p.add_argument( + "--no-wandb-run-name", + action="store_true", + help="With W&B logging, omit +logger.wandb.name=... (entity and project still set).", + ) + p.add_argument( + "--all-datasets", + action="store_true", + help=( + "Do not restrict to main_loader datasets without cocitation trio; include every " + "dataset present in the CSV." 
+ ), + ) + args = p.parse_args() + + wb_ent: str | None = None + wb_proj: str | None = None + if not args.no_wandb_logger: + wb_ent = str(args.wandb_entity).replace("\r", "").strip() + wb_proj = str(args.wandb_project).replace("\r", "").strip() + + df = load_wandb_export_csv(args.input) + n_in = len(df) + if not args.all_datasets: + df = dataframe_filter_rerun_datasets(df, allowed_hydra=DEFAULT_RERUN_ALLOWED_HYDRA_DATASETS) + print( + f"Dataset filter: {n_in} -> {len(df)} rows " + f"(main_loader.DATASETS minus cocitation cora/citeseer/pubmed; " + f"{len(DEFAULT_RERUN_ALLOWED_HYDRA_DATASETS)} allowed Hydra paths)" + ) + + if args.keep_row_seed and SEED_COLUMN not in df.columns: + print( + f"Note: --keep-row-seed but CSV has no {SEED_COLUMN!r} column " + f"(expected for seed-aggregated exports); using --data-seeds." + ) + + seeds = _parse_data_seeds(str(args.data_seeds).replace("\r", "")) + + common_kw = dict( + interpreter=args.interpreter, + data_seeds=seeds, + append_args=list(args.append_arg), + keep_row_seed=args.keep_row_seed, + group_cols=list(args.group_by), + max_epochs=int(args.max_epochs), + early_stopping_patience=args.early_stopping_patience, + fixed_args_profile=str(args.fixed_args_profile), + wandb_entity=wb_ent, + wandb_project=wb_proj, + wandb_run_name=not args.no_wandb_run_name, + ) + + n = emit_sequential_rerun_script(df, path=args.output, **common_kw) + print(f"Wrote {n} sequential command(s) -> {args.output}") + + if not args.no_parallel_script: + gpus = _parse_parallel_gpus(args.parallel_gpus) + n2 = emit_parallel_rerun_script(df, path=args.output_parallel, gpu_ids=gpus, **common_kw) + print( + f"Wrote {n2} parallel command(s) -> {args.output_parallel} " + f"(GPUS round-robin: {gpus})" + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/hopse_plotting/hyperparam_analysis.py b/scripts/hopse_plotting/hyperparam_analysis.py new file mode 100644 index 000000000..fec4c13d8 --- /dev/null +++ b/scripts/hopse_plotting/hyperparam_analysis.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +""" +Hyperparameter sensitivity from a **seed-aggregated** W&B export (``aggregator`` +output): rows are grouped by **(model, dataset)** so the monitored validation metric +is consistent within each panel. For each group, every config column that still varies +is plotted against validation performance. + +Figures are written under ``plots/hyperparam///`` by default. + +Validation **y** uses ``dataset.parameters.monitor_metric`` and the same +``summary_*__mean`` column resolution as ``collapse_aggregated_wandb_by_best_val``. +For metrics where lower is better (see ``MONITOR_METRIC_OPTIMIZATION`` in utils), the +y-axis label adds a bold **lower is better** line. + +- Mostly numeric columns with many distinct values → scatter plot. +- Otherwise → **violin plot** per category with **jittered dots** (one point per + aggregated row / seed-mean run) on top. + +Does not modify ``main_loader``. Any model id present in the CSV (including +``simplicial/sccnn_custom``, ``cell/cwn``, ``cell/sann``, …) gets a +``(model, dataset)`` folder automatically. 
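In sketch form, the scatter-vs-violin choice above boils down to something like this (cutoff values here are illustrative assumptions; the actual rule is ``infer_hyperparam_plot_kind`` in ``utils``)::

    import pandas as pd

    def sketch_plot_kind(x: pd.Series, max_categories: int = 12) -> str:
        # Mostly-numeric columns with many distinct values -> scatter;
        # a manageable number of categories -> violin; far too many -> skip.
        x_num = pd.to_numeric(x, errors="coerce")
        mostly_numeric = x_num.notna().mean() >= 0.9
        n_distinct = int(x.astype(str).nunique(dropna=False))
        if mostly_numeric and n_distinct > max_categories:
            return "scatter"
        if n_distinct > 4 * max_categories:
            return "skip"
        return "violin"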
+ +Usage:: + + python scripts/hopse_plotting/aggregator.py + python scripts/hopse_plotting/hyperparam_analysis.py + python scripts/hopse_plotting/hyperparam_analysis.py -i scripts/hopse_plotting/csvs/hopse_experiments_wandb_export_seed_agg.csv + python scripts/hopse_plotting/hyperparam_analysis.py --from-raw -i scripts/hopse_plotting/csvs/hopse_experiments_wandb_export.csv -o plots/out + python scripts/hopse_plotting/hyperparam_analysis.py --models cell/hopse_m simplicial/hopse_m + python scripts/hopse_plotting/hyperparam_analysis.py --models simplicial/sccnn_custom cell/cwn + python scripts/hopse_plotting/hyperparam_analysis.py --datasets graph/MUTAG +""" + +from __future__ import annotations + +import argparse +import zlib +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 +import numpy as np # noqa: E402 +import pandas as pd # noqa: E402 + +from utils import ( + DEFAULT_AGGREGATED_EXPORT_CSV, + DEFAULT_HYPERPARAM_PLOT_DIR, + DEFAULT_WANDB_EXPORT_CSV, + MONITOR_METRIC_COLUMN, + aggregate_wandb_export_by_seed, + hyperparam_axis_columns, + infer_hyperparam_plot_kind, + load_wandb_export_csv, + metric_name_tail, + optimization_mode_for_metric_tail, + safe_filename_token, + val_metric_mean_per_row, + varied_hyperparam_columns, +) + + +def _candidates_for_model_dataset_groups(df: pd.DataFrame) -> list[str]: + """Hyperparam columns to scan; ``dataset`` is fixed per group and omitted.""" + return [c for c in hyperparam_axis_columns(df) if c != "dataset"] + + +def _stable_rng_seed(*parts: str) -> int: + h = zlib.adler32(b"\0".join(p.encode("utf-8", errors="replace") for p in parts)) & 0xFFFFFFFF + return int(h % (2**31 - 1)) or 1 + + +def _ylabel_validation_monitor(sub_all: pd.DataFrame) -> str: + """ + Y-axis label for validation (seed-mean). When the slice's monitor metric is + minimized (``MONITOR_METRIC_OPTIMIZATION``), append a **bold** mathtext line. 
+ """ + base = "Validation metric (seed-mean)\n(row monitor)" + if MONITOR_METRIC_COLUMN not in sub_all.columns: + return base + mon = ( + sub_all[MONITOR_METRIC_COLUMN] + .dropna() + .astype(str) + .str.strip() + ) + mon = mon[(mon != "") & ~mon.str.lower().isin({"nan", "none"})] + if mon.empty: + return base + tail = metric_name_tail(mon.iloc[0]) + if not tail: + return base + if optimization_mode_for_metric_tail(tail) == "min": + return base + "\n" + r"$\bf{lower\ is\ better}$" + return base + + +def _prepare_frame( + input_path: Path, + *, + from_raw: bool, +) -> pd.DataFrame: + raw = load_wandb_export_csv(input_path) + if from_raw: + return aggregate_wandb_export_by_seed(raw) + return raw + + +def _plot_one_hyperparam( + ax: plt.Axes, + x: pd.Series, + y: pd.Series, + *, + col_name: str, +) -> None: + y_num = pd.to_numeric(y, errors="coerce") + mask_y = y_num.notna().to_numpy() + x = x[mask_y] + y = y_num[mask_y] + + if len(y) == 0: + ax.text(0.5, 0.5, "no finite y", ha="center", va="center", transform=ax.transAxes) + ax.set_axis_off() + return + + kind, xv = infer_hyperparam_plot_kind(x) + + if kind == "skip": + ax.text( + 0.5, + 0.5, + f"skipped (> categories)\n{col_name}", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=9, + ) + ax.set_axis_off() + return + + if kind == "scatter": + xv_num = pd.to_numeric(xv, errors="coerce") + ok = xv_num.notna().to_numpy() & y.notna().to_numpy() + xv = xv_num[ok] + y = y[ok] + ax.scatter( + xv, + y, + s=26, + alpha=0.72, + edgecolors="0.2", + linewidths=0.45, + color="#2E6F9E", + ) + ax.set_xlabel(col_name, fontsize=8) + return + + # Violin + jittered seed-mean points per category + sub = pd.DataFrame({"x": xv.astype(str), "y": pd.to_numeric(y, errors="coerce")}) + sub = sub[np.isfinite(sub["y"])] + if sub.empty: + ax.text(0.5, 0.5, "no finite y", ha="center", va="center", transform=ax.transAxes) + ax.set_axis_off() + return + + medians = sub.groupby("x", dropna=False)["y"].median() + order = medians.sort_values(ascending=False).index.tolist() + xpos = np.arange(len(order), dtype=float) + data_arrays = [sub.loc[sub["x"] == cat, "y"].to_numpy(dtype=float) for cat in order] + + # Violins (matplotlib KDE; single-point categories still get a narrow shape) + parts = ax.violinplot( + data_arrays, + positions=xpos, + widths=min(0.82, 0.14 * max(len(order), 1)), + showmeans=False, + showmedians=True, + showextrema=False, + ) + for b in parts["bodies"]: + b.set_facecolor("#4A90A4") + b.set_alpha(0.38) + b.set_edgecolor("0.22") + b.set_linewidth(0.75) + if "cmedians" in parts and parts["cmedians"] is not None: + parts["cmedians"].set_colors("0.15") + parts["cmedians"].set_linewidths(1.0) + + rng = np.random.default_rng(_stable_rng_seed(col_name, *order[:8])) + + for i, cat in enumerate(order): + ys = sub.loc[sub["x"] == cat, "y"].to_numpy(dtype=float) + ys = ys[np.isfinite(ys)] + if ys.size == 0: + continue + jitter = rng.uniform(-0.14, 0.14, size=ys.size) + ax.scatter( + xpos[i] + jitter, + ys, + s=20, + alpha=0.88, + c="0.12", + edgecolors="white", + linewidths=0.4, + zorder=3, + ) + + ax.set_xticks(xpos) + ax.set_xticklabels([str(l) for l in order], rotation=42, ha="right", fontsize=7) + ax.set_xlim(xpos.min() - 0.65, xpos.max() + 0.65) + ax.set_xlabel(col_name, fontsize=8) + + +def run_hyperparam_analysis( + df: pd.DataFrame, + output_dir: Path, + *, + models: list[str] | None = None, + datasets: list[str] | None = None, + dpi: int = 200, +) -> None: + if "model" not in df.columns: + raise KeyError("expected column 'model' 
(seed-aggregated export)") + if "dataset" not in df.columns: + raise KeyError("expected column 'dataset' for (model, dataset) grouping") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + plt.rcParams.update( + { + "font.family": "serif", + "font.serif": ["Times New Roman", "DejaVu Serif", "Times", "serif"], + "axes.labelsize": 9, + "axes.titlesize": 10, + "figure.dpi": dpi, + "savefig.dpi": dpi, + } + ) + + y_all = val_metric_mean_per_row(df) + candidates = _candidates_for_model_dataset_groups(df) + + keys = df[["model", "dataset"]].astype(str).drop_duplicates() + combos = [(r["model"], r["dataset"]) for _, r in keys.iterrows()] + combos.sort(key=lambda t: (t[0], t[1])) + + if models: + want_m = {str(m) for m in models} + combos = [(m, d) for m, d in combos if m in want_m] + missing_m = want_m - {m for m, _ in combos} + if missing_m: + print(f" (warn) --models not found in CSV: {sorted(missing_m)}") + + if datasets: + want_d = {str(x) for x in datasets} + combos = [(m, d) for m, d in combos if d in want_d] + missing_d = want_d - {d for _, d in combos} + if missing_d: + print(f" (warn) --datasets not found in CSV: {sorted(missing_d)}") + + for m, d in combos: + sub_all = df[(df["model"].astype(str) == m) & (df["dataset"].astype(str) == d)] + varied = varied_hyperparam_columns(sub_all, candidate_cols=candidates) + if not varied: + print(f" (skip) no varied hyperparam columns for ({m!r}, {d!r})") + continue + + y = y_all.loc[sub_all.index] + ylabel_str = _ylabel_validation_monitor(sub_all) + + safe_m = safe_filename_token(m.replace("/", "__")) + safe_d = safe_filename_token(str(d).replace("/", "__")) + combo_dir = output_dir / safe_m / safe_d + combo_dir.mkdir(parents=True, exist_ok=True) + + for col in varied: + _kind_w, _xv_w = infer_hyperparam_plot_kind(sub_all[col]) + if _kind_w == "scatter": + fig_w = 5.2 + elif _kind_w == "skip": + fig_w = 5.2 + else: + n_lab = int(_xv_w.astype(str).nunique(dropna=False)) + fig_w = min(14.0, max(5.2, 0.42 * float(max(n_lab, 1)) + 3.0)) + fig, ax = plt.subplots(figsize=(fig_w, 3.6)) + _plot_one_hyperparam(ax, sub_all[col], y, col_name=col) + ax.set_ylabel(ylabel_str, fontsize=8) + ax.set_title(f"{m}\n{d}\n{col}", fontsize=9, fontweight="semibold") + ax.yaxis.grid(True, linestyle=":", linewidth=0.5, alpha=0.85) + ax.set_axisbelow(True) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + fig.tight_layout() + fn = f"{safe_filename_token(col, max_len=96)}.png" + fig.savefig(combo_dir / fn, bbox_inches="tight", facecolor="white", edgecolor="none") + plt.close(fig) + + print(f" Wrote {len(varied)} figure(s) -> {combo_dir}") + + +def main() -> None: + p = argparse.ArgumentParser( + description="Plot val metric vs varied hyperparams per (model, dataset) (seed-aggregated W&B CSV)." 
+ ) + p.add_argument( + "-i", + "--input", + type=Path, + default=None, + help="CSV path (default: seed-aggregated export, or raw export with --from-raw)", + ) + p.add_argument( + "-o", + "--output-dir", + type=Path, + default=DEFAULT_HYPERPARAM_PLOT_DIR, + help=( + "Directory root; each (model, dataset) gets a subfolder model/dataset/ " + f"(default: {DEFAULT_HYPERPARAM_PLOT_DIR})" + ), + ) + p.add_argument( + "--from-raw", + action="store_true", + help="Treat --input as per-run loader export; aggregate over seeds in memory (no CSV write).", + ) + p.add_argument( + "--models", + nargs="+", + default=None, + metavar="MODEL", + help="Only these model ids (exact strings as in CSV 'model' column).", + ) + p.add_argument( + "--datasets", + nargs="+", + default=None, + metavar="DATASET", + help="Only these dataset paths (exact strings as in CSV 'dataset' column).", + ) + p.add_argument( + "--dpi", + type=int, + default=200, + help="Figure DPI (default: 200)", + ) + args = p.parse_args() + + inp = args.input + if inp is None: + inp = DEFAULT_WANDB_EXPORT_CSV if args.from_raw else DEFAULT_AGGREGATED_EXPORT_CSV + + df = _prepare_frame(inp, from_raw=args.from_raw) + print(f"Loaded {'raw→aggregated' if args.from_raw else 'seed-aggregated'} table: {len(df)} rows x {len(df.columns)} cols") + run_hyperparam_analysis( + df, + args.output_dir, + models=args.models, + datasets=args.datasets, + dpi=args.dpi, + ) + print(f"Done. Plots under {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/hopse_plotting/main_loader.py b/scripts/hopse_plotting/main_loader.py new file mode 100644 index 000000000..9ab4f4368 --- /dev/null +++ b/scripts/hopse_plotting/main_loader.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +""" +Export TopoBench / HOPSE sweeps from Weights & Biases into CSV file(s). + +Expected W&B project names follow the sweep scripts: ``{model}_{dataset_basename}``, +e.g. ``graph/CYP3A4_Veith`` + model ``gin`` → project ``gin_CYP3A4_Veith``. + +Requires ``WANDB_API_KEY`` (or prior ``wandb login``) and the ``wandb`` package. + +By default writes **multiple** smaller CSVs (one per model, all datasets) under +``csvs/hopse_experiments_wandb_export_shards/`` so memory stays bounded. Use +``--shard-by none -o path.csv`` for a single monolithic export under ``csvs/``. +Run ``aggregator`` to produce **one** combined seed-aggregated CSV. 
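That naming convention is easy to make concrete (``sketch_project_name`` is a hypothetical helper shown only for illustration, not part of this module)::

    def sketch_project_name(model: str, dataset: str) -> str:
        # "{model}_{dataset_basename}": ("gin", "graph/CYP3A4_Veith") -> "gin_CYP3A4_Veith"
        return f"{model}_{dataset.rsplit('/', 1)[-1]}"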
+ +Usage (from repo root):: + + python scripts/hopse_plotting/main_loader.py + python scripts/hopse_plotting/main_loader.py --shard-by dataset + python scripts/hopse_plotting/main_loader.py --shard-by none \\ + -o scripts/hopse_plotting/csvs/hopse_experiments_wandb_export.csv +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +from utils import ( + DEFAULT_WANDB_EXPORT_CSV, + DEFAULT_WANDB_EXPORT_SHARD_DIR, + collect_all_runs, + dataframe_from_rows, + safe_filename_token, +) + +# ----------------------------------------------------------------------------- +# Hard-coded sweep coverage (edit here) +# ----------------------------------------------------------------------------- + +WANDB_ENTITY = "gbg141-hopse" + +MODELS = ["gin", "gat", "gcn", "topotune", "hopse_m", "hopse_g", "sann", "sccnn", "cwn"] + +DATASETS = [ + "graph/MUTAG", + "graph/PROTEINS", + "graph/NCI1", + "graph/NCI109", + "simplicial/mantra_name", + "simplicial/mantra_orientation", + "simplicial/mantra_betti_numbers", + "graph/BBB_Martins", + "graph/CYP3A4_Veith", + "graph/Clearance_Hepatocyte_AZ", + "graph/Caco2_Wang", +] + +DEFAULT_OUTPUT_CSV = DEFAULT_WANDB_EXPORT_CSV + + +def main() -> None: + parser = argparse.ArgumentParser(description="Export W&B TopoBench sweeps to CSV.") + parser.add_argument( + "--output", + "-o", + type=Path, + default=DEFAULT_OUTPUT_CSV, + help=f"Single output CSV path when --shard-by none (default: {DEFAULT_OUTPUT_CSV}). Ignored when sharding.", + ) + parser.add_argument( + "--entity", + default=WANDB_ENTITY, + help="W&B entity (default: hard-coded WANDB_ENTITY)", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Less console output", + ) + parser.add_argument( + "--run-state", + default="finished", + metavar="STATE", + help='W&B run filter: "finished" (default), "running", "crashed", "failed", or "all" for no filter', + ) + parser.add_argument( + "--shard-by", + choices=("none", "model", "dataset"), + default="model", + help='Write one CSV per "model" (all datasets, default) or per "dataset" (all models). Use "none" for a single -o file.', + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help=f"Directory for sharded CSVs (default: {DEFAULT_WANDB_EXPORT_SHARD_DIR}).
Ignored when --shard-by none.", + ) + parser.add_argument( + "--basename", + default="hopse_experiments_wandb_export", + help="File stem for sharded files (default: hopse_experiments_wandb_export).", + ) + args = parser.parse_args() + + run_state: str | None + if str(args.run_state).lower() == "all": + run_state = None + else: + run_state = str(args.run_state) + + print(f"Entity: {args.entity}") + print(f"Models ({len(MODELS)}): {MODELS}") + print(f"Datasets ({len(DATASETS)}): {DATASETS}") + + if args.shard_by == "none": + print("Collecting runs …") + rows = collect_all_runs( + args.entity, + MODELS, + DATASETS, + run_state=run_state, + verbose=not args.quiet, + ) + df = dataframe_from_rows(rows) + args.output.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(args.output, index=False) + print(f"Wrote {len(df)} rows x {len(df.columns)} columns -> {args.output}") + return + + out_dir = args.output_dir or DEFAULT_WANDB_EXPORT_SHARD_DIR + out_dir.mkdir(parents=True, exist_ok=True) + verbose = not args.quiet + + if args.shard_by == "model": + print(f"Collecting runs (sharded by model) into {out_dir} …") + for model in MODELS: + if verbose: + print(f" (shard) model={model!r}") + rows = collect_all_runs( + args.entity, + [model], + DATASETS, + run_state=run_state, + verbose=verbose, + ) + df = dataframe_from_rows(rows) + stem = safe_filename_token(str(model).replace("/", "__")) + path = out_dir / f"{args.basename}__{stem}.csv" + path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(path, index=False) + print(f" -> {len(df)} rows x {len(df.columns)} columns -> {path}") + return + + print(f"Collecting runs (sharded by dataset) into {out_dir} …") + for ds in DATASETS: + if verbose: + print(f" (shard) dataset={ds!r}") + rows = collect_all_runs( + args.entity, + MODELS, + [ds], + run_state=run_state, + verbose=verbose, + ) + df = dataframe_from_rows(rows) + stem = safe_filename_token(str(ds).replace("/", "__")) + path = out_dir / f"{args.basename}__{stem}.csv" + path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(path, index=False) + print(f" -> {len(df)} rows x {len(df.columns)} columns -> {path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/hopse_plotting/main_plot.py b/scripts/hopse_plotting/main_plot.py new file mode 100644 index 000000000..285ecd406 --- /dev/null +++ b/scripts/hopse_plotting/main_plot.py @@ -0,0 +1,527 @@ +#!/usr/bin/env python3 +""" +Collapse seed-aggregated W&B CSV to one best row per (model, dataset, ...), then +optionally save a publication-style figure comparing models across datasets. + +Chooses the hyperparameter row with optimal validation mean for the dataset's +``dataset.parameters.monitor_metric`` (higher-is-better vs lower-is-better from +``MONITOR_METRIC_OPTIMIZATION`` in ``utils``), using ``utils.collapse_aggregated_wandb_csv`` / +``iter_best_val_group_picks`` — the **same** best-val rule as ``table_generator`` and the +winner-row selection in ``best_rerun_sh_generator``. The output includes paired +train/val/test **mean** and **std** columns from the aggregated CSV. + +Seed aggregation and reruns assume the export includes all swept config columns listed in +``utils.CONFIG_PARAM_KEYS`` (e.g. HOPSE_G ``transforms.hopse_encoding.pretrain_model``, +SANN ``transforms.sann_encoding.*``); see ``main_loader`` / ``aggregator`` docs. + +The figure uses one column per dataset (max 4 per row), bars = models, error bars +from the **test** split by default (mean ± std); configs are still **chosen** using +validation. Override with ``--split``. 
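In sketch form, that validation-based selection is a per-group argmax/argmin (illustrative pandas only; the real logic, including monitor-metric column resolution, lives in ``iter_best_val_group_picks``)::

    import pandas as pd

    def sketch_best_val_pick(df: pd.DataFrame, val_col: str, mode: str = "max") -> pd.DataFrame:
        # Keep the single best-validation row per (model, dataset);
        # ``mode`` follows the monitored metric's optimization direction.
        ranked = df.sort_values(val_col, ascending=(mode == "min"))
        return ranked.groupby(["model", "dataset"], sort=False).head(1)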
y-axis shows the monitored metric with an arrow +for optimization direction. + +Default CSV paths live under ``scripts/hopse_plotting/csvs/``; default figure path is +``plots/leaderboard/_leaderboard.png``. + +Usage:: + + python scripts/hopse_plotting/main_plot.py + python scripts/hopse_plotting/main_plot.py -i scripts/hopse_plotting/csvs/hopse_experiments_wandb_export_seed_agg.csv \\ + -o scripts/hopse_plotting/csvs/hopse_experiments_wandb_export_collapsed.csv + python scripts/hopse_plotting/main_plot.py --no-plot + python scripts/hopse_plotting/main_plot.py --split val --plot-output plots/leaderboard/fig.png +""" + +from __future__ import annotations + +import argparse +import math +from collections import Counter +from pathlib import Path +from typing import Literal + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 +import numpy as np # noqa: E402 +import pandas as pd # noqa: E402 +from matplotlib.patches import Patch # noqa: E402 + +from utils import ( + DEFAULT_AGGREGATED_EXPORT_CSV, + DEFAULT_COLLAPSED_EXPORT_CSV, + DEFAULT_LEADERBOARD_PLOT_DIR, + MONITOR_METRIC_COLUMN, + collapse_aggregated_wandb_csv, + metric_name_tail, + optimization_mode_for_metric_tail, + safe_metric_col_token, +) + +def _label_short(s: str, max_len: int = 28) -> str: + t = str(s).strip() + if "/" in t: + t = t.rsplit("/", 1)[-1] + return t if len(t) <= max_len else t[: max_len - 1] + "..." + + +def _dataset_title(dataset_path: str) -> str: + return _label_short(dataset_path, max_len=40) + + +def _legend_labels_for_models(models: list[str]) -> dict[str, str]: + """ + One legend entry per raw model id. If several models share the same short + basename (e.g. simplicial/... and cell/...), disambiguate with the domain prefix. + """ + uniq = sorted(set(str(m) for m in models)) + short_for = {m: _label_short(m) for m in uniq} + short_count = Counter(short_for.values()) + seen: set[str] = set() + out: dict[str, str] = {} + for m in uniq: + sh = short_for[m] + if short_count[sh] > 1 and "/" in m: + domain = m.split("/", 1)[0] + label = f"{sh} ({domain})" + else: + label = sh + if label in seen: + label = m + seen.add(label) + out[m] = label + return out + + +def _model_basename(m: str) -> str: + return str(m).strip().rsplit("/", 1)[-1].lower() + + +def _model_domain(m: str) -> str: + """Hydra-style prefix: ``cell/...``, ``simplicial/...``, ``graph/...``.""" + s = str(m).strip() + if "/" not in s: + return "" + return s.split("/", 1)[0].strip().lower() + + +def model_category(m: str) -> str: + """Coarse family for color + strip ordering (MPNN, TopoTune, SANN, SCCNN, CWN, HOPSE).""" + s = str(m).lower() + b = _model_basename(m) + if "hopse_g" in s or b == "hopse_g": + return "hopse_g" + if "hopse_m" in s or b == "hopse_m": + return "hopse_m" + if "topotune" in s or b == "topotune": + return "topotune" + if b.startswith("sann"): + return "sann" + if b.startswith("sccnn"): + return "sccnn" + if b == "cwn": + return "cwn" + if b in ("gcn", "gin", "gat"): + return "mpnn" + return "other" + + +# Panel / legend order: MPNN, TopoTune, SANN, SCCNN, CWN, HOPSE-M, HOPSE-G, then other. 
+_CATEGORY_ORDER = { + "mpnn": 0, + "topotune": 1, + "sann": 2, + "sccnn": 3, + "cwn": 4, + "hopse_m": 5, + "hopse_g": 6, + "other": 7, +} +_MPNN_ORDER = {"gcn": 0, "gin": 1, "gat": 2} +# cell before simplicial within TopoTune / SANN / each HOPSE line +_DOMAIN_SORT = {"cell": 0, "simplicial": 1, "graph": 2} + +# MPNN: yellow -> orange -> red +_MPNN_HEX = {"gcn": "#E6C229", "gin": "#F28C18", "gat": "#C81D25"} +# TopoTune: lighter green (cell) vs deeper green (simplicial) +_TOPOTUNE_CELL_HEX = "#8FD98A" +_TOPOTUNE_SIM_HEX = "#1E7A3E" +# HOPSE blues, light -> dark: M-cell, M-sim, G-cell, G-sim +_HOPSE_M_CELL_HEX = "#C8E6F5" +_HOPSE_M_SIM_HEX = "#7DB6E8" +_HOPSE_G_CELL_HEX = "#3D78B8" +_HOPSE_G_SIM_HEX = "#0C335C" +# SANN: warm amber (cell lighter, simplicial deeper) — distinct from TopoTune greens +_SANN_CELL_HEX = "#F39C12" +_SANN_SIM_HEX = "#B9770E" +_SCCNN_HEX = "#8E44AD" +_CWN_HEX = "#1ABC9C" +_OTHER_HEX = "#6E6E6E" + + +def _domain_sort_key(m: str) -> int: + d = _model_domain(m) + return _DOMAIN_SORT.get(d, 9) + + +def model_display_sort_key(m: str) -> tuple: + cat = model_category(m) + b = _model_basename(m) + cat_i = _CATEGORY_ORDER[cat] + if cat == "mpnn": + return (cat_i, _MPNN_ORDER.get(b, 9), m.lower()) + if cat in ("topotune", "hopse_m", "hopse_g", "sann", "sccnn", "cwn"): + return (cat_i, _domain_sort_key(m), m.lower()) + return (cat_i, b, m.lower()) + + +def models_sorted_for_display(models) -> list[str]: + return sorted({str(x) for x in models}, key=model_display_sort_key) + + +def color_for_model(m: str) -> tuple: + cat = model_category(m) + b = _model_basename(m) + dom = _model_domain(m) + if cat == "mpnn": + h = _MPNN_HEX.get(b, _OTHER_HEX) + elif cat == "topotune": + if dom == "cell": + h = _TOPOTUNE_CELL_HEX + elif dom == "simplicial": + h = _TOPOTUNE_SIM_HEX + else: + h = _TOPOTUNE_SIM_HEX + elif cat == "hopse_m": + if dom == "cell": + h = _HOPSE_M_CELL_HEX + elif dom == "simplicial": + h = _HOPSE_M_SIM_HEX + else: + h = _HOPSE_M_SIM_HEX + elif cat == "hopse_g": + if dom == "cell": + h = _HOPSE_G_CELL_HEX + elif dom == "simplicial": + h = _HOPSE_G_SIM_HEX + else: + h = _HOPSE_G_SIM_HEX + elif cat == "sann": + if dom == "cell": + h = _SANN_CELL_HEX + elif dom == "simplicial": + h = _SANN_SIM_HEX + else: + h = _SANN_SIM_HEX + elif cat == "sccnn": + h = _SCCNN_HEX + elif cat == "cwn": + h = _CWN_HEX + else: + h = _OTHER_HEX + return matplotlib.colors.to_rgb(h) + + +def metric_axis_label(monitor_raw: str) -> str: + """Y-axis text: metric name + optimize direction arrow (matplotlib mathtext).""" + tail = metric_name_tail(monitor_raw) + if not tail: + tail = "metric" + mode = optimization_mode_for_metric_tail(tail) + arrow = r"$\uparrow$" if mode == "max" else r"$\downarrow$" + return f"{arrow} {tail}" + + +def _mean_std_columns_for_row( + monitor_raw: str, split: Literal["train", "val", "test"] +) -> tuple[str, str]: + tail = metric_name_tail(monitor_raw) + tok = safe_metric_col_token(tail) if tail else "unknown" + return f"{split}_{tok}_mean", f"{split}_{tok}_std" + + +# Touching grouped bars: unit spacing, full width 1.0 (edges meet between neighbors). 
+_BAR_TOUCHING_WIDTH = 1.0 + + +def _set_ylim_from_values_with_errors( + ax, + means: list[float], + stds: list[float], + *, + pad_frac: float = 0.06, +) -> None: + """Tight y-axis from min(mean - std) to max(mean + std) with small padding.""" + m = np.asarray(means, dtype=float) + s = np.asarray(stds, dtype=float) + ok = np.isfinite(m) + if not ok.any(): + return + s = np.nan_to_num(s, nan=0.0) + lo = float(np.min(m[ok] - s[ok])) + hi = float(np.max(m[ok] + s[ok])) + if not (np.isfinite(lo) and np.isfinite(hi)): + return + span = hi - lo + pad = pad_frac * span if span > 1e-12 else max(abs(hi), 1e-9) * pad_frac + ax.set_ylim(lo - pad, hi + pad) + + +def plot_collapsed_model_leaderboard( + collapsed_df: pd.DataFrame, + *, + path: Path, + monitor_column: str = MONITOR_METRIC_COLUMN, + split: Literal["train", "val", "test"] = "test", + max_cols: int = 4, + dpi: int = 300, +) -> None: + """ + Bar plot: facet columns = datasets (max ``max_cols`` per row), bars = models, + heights = mean, error bars = std for ``split`` (train/val/test). Bars use fixed + unit width and touch; no x ticks (identify models by legend). Y-limits are + data-driven per panel. + """ + if collapsed_df.empty: + raise ValueError("collapsed_df is empty; nothing to plot.") + if max_cols < 1: + raise ValueError("max_cols must be >= 1.") + + df = collapsed_df.copy() + if monitor_column not in df.columns: + raise KeyError(f"missing {monitor_column!r} in collapsed dataframe") + if "model" not in df.columns or "dataset" not in df.columns: + raise KeyError("collapsed dataframe must contain 'model' and 'dataset' columns") + + models_all = models_sorted_for_display(df["model"].astype(str).unique()) + color_by_model = {m: color_for_model(m) for m in models_all} + legend_label_by_model = _legend_labels_for_models(models_all) + + datasets = sorted(df["dataset"].astype(str).unique()) + n_ds = len(datasets) + n_rows = math.ceil(n_ds / max_cols) if n_ds else 1 + n_cols_fig = min(max_cols, n_ds) if n_ds else 1 + + # Publication-friendly rc + plt.rcParams.update( + { + "font.family": "serif", + "font.serif": ["Times New Roman", "DejaVu Serif", "Times", "serif"], + "axes.labelsize": 10, + "axes.titlesize": 11, + "xtick.labelsize": 9, + "ytick.labelsize": 9, + "legend.fontsize": 9, + "axes.linewidth": 0.8, + "xtick.major.width": 0.6, + "ytick.major.width": 0.6, + "figure.dpi": dpi, + "savefig.dpi": dpi, + "savefig.bbox": "tight", + } + ) + + # Wide enough for a single-row model legend when there are many models + fig_w = min(22.0, max(6.0, n_cols_fig * 3.05, 5.0 + 0.44 * len(models_all))) + fig_h = max(2.95, n_rows * 3.22) + fig, axes = plt.subplots( + n_rows, + n_cols_fig, + figsize=(fig_w, fig_h), + squeeze=False, + ) + + panel_idx = 0 + for row in range(n_rows): + for col in range(n_cols_fig): + ax = axes[row][col] + if panel_idx >= n_ds: + ax.set_visible(False) + continue + + ds = datasets[panel_idx] + sub = df[df["dataset"].astype(str) == ds].copy() + monitor = str(sub[monitor_column].iloc[0]).strip() + mean_c, std_c = _mean_std_columns_for_row(monitor, split) + + if mean_c not in sub.columns: + ax.text(0.5, 0.5, f"missing column\n{mean_c}", ha="center", va="center", transform=ax.transAxes) + ax.set_title(_dataset_title(ds)) + panel_idx += 1 + continue + + models_here = models_sorted_for_display(sub["model"].astype(str).unique()) + x = np.arange(len(models_here)) + means: list[float] = [] + stds: list[float] = [] + colors: list[tuple] = [] + for m in models_here: + row_m = sub[sub["model"].astype(str) == m].iloc[0] + mu = 
pd.to_numeric(row_m.get(mean_c, np.nan), errors="coerce") + sg = pd.to_numeric(row_m.get(std_c, np.nan), errors="coerce") if std_c in sub.columns else np.nan + means.append(float(mu) if pd.notna(mu) else np.nan) + stds.append(float(sg) if pd.notna(sg) else 0.0) + colors.append(color_by_model.get(m, (0.4, 0.4, 0.4))) + + n_b = len(models_here) + ax.bar( + x, + means, + width=_BAR_TOUCHING_WIDTH, + align="center", + yerr=stds, + color=colors, + edgecolor="0.12", + linewidth=0.55, + capsize=2.2, + error_kw={"elinewidth": 0.85, "capthick": 0.85, "color": "0.22"}, + ) + ax.set_xticks([]) + ax.set_xticklabels([]) + ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False) + if n_b > 0: + ax.set_xlim(-0.5, (n_b - 1) + 0.5) + ax.set_title(_dataset_title(ds), fontweight="semibold", pad=6) + ax.set_ylabel(metric_axis_label(monitor)) + _set_ylim_from_values_with_errors(ax, means, stds) + ax.yaxis.grid(True, linestyle=":", linewidth=0.5, alpha=0.85) + ax.set_axisbelow(True) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + panel_idx += 1 + + handles = [ + Patch( + facecolor=color_by_model[m], + edgecolor="0.15", + linewidth=0.6, + label=legend_label_by_model[m], + ) + for m in models_all + ] + ncol = max(1, len(handles)) + + # Header: title pulled down, legend pulled up (nearer title); axes top below legend box + fig.subplots_adjust( + left=0.07, + right=0.99, + bottom=0.08, + top=0.795, + wspace=0.30, + hspace=0.72, + ) + + fig.suptitle( + f"Best config per model (selected on val); bars: {split} mean +/- std across seeds", + fontsize=12, + fontweight="bold", + y=0.918, + ) + fig.legend( + handles=handles, + loc="upper center", + bbox_to_anchor=(0.5, 0.888), + ncol=ncol, + frameon=False, + title="Model", + fontsize=11, + title_fontsize=11, + borderaxespad=0.15, + labelspacing=0.4, + handletextpad=0.65, + columnspacing=1.75, + ) + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(path, dpi=dpi, bbox_inches="tight", facecolor="white", edgecolor="none") + plt.close(fig) + + +def main() -> None: + p = argparse.ArgumentParser( + description="Collapse seed-aggregated W&B CSV and plot model comparison across datasets." + ) + p.add_argument( + "-i", + "--input", + type=Path, + default=DEFAULT_AGGREGATED_EXPORT_CSV, + help=f"Seed-aggregated CSV (default: {DEFAULT_AGGREGATED_EXPORT_CSV})", + ) + p.add_argument( + "-o", + "--output", + type=Path, + default=DEFAULT_COLLAPSED_EXPORT_CSV, + help=f"Collapsed CSV (default: {DEFAULT_COLLAPSED_EXPORT_CSV})", + ) + p.add_argument( + "--group-by", + metavar="COL", + nargs="+", + default=["model", "dataset"], + help="Group columns (default: model dataset)", + ) + p.add_argument( + "--no-plot", + action="store_true", + help="Only write collapsed CSV, skip figure.", + ) + p.add_argument( + "--plot-output", + type=Path, + default=None, + help=( + f"Figure path (.png / .pdf). 
Default: {DEFAULT_LEADERBOARD_PLOT_DIR}/" + "_leaderboard.png" + ), + ) + p.add_argument( + "--split", + choices=("train", "val", "test"), + default="test", + help="Which split's mean±std to plot (default: test; selection still uses val)", + ) + p.add_argument( + "--max-cols", + type=int, + default=4, + metavar="N", + help="Max dataset panels per row (default: 4)", + ) + p.add_argument( + "--dpi", + type=int, + default=300, + help="Figure DPI (default: 300)", + ) + args = p.parse_args() + + collapsed = collapse_aggregated_wandb_csv( + args.input, + args.output, + group_cols=list(args.group_by), + ) + print(f"Wrote {len(collapsed)} rows x {len(collapsed.columns)} columns -> {args.output}") + + if args.no_plot: + return + + plot_path = args.plot_output + if plot_path is None: + plot_path = DEFAULT_LEADERBOARD_PLOT_DIR / f"{args.output.stem}_leaderboard.png" + + plot_collapsed_model_leaderboard( + collapsed, + path=plot_path, + split=args.split, + max_cols=args.max_cols, + dpi=args.dpi, + ) + print(f"Wrote figure -> {plot_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/hopse_plotting/seed_n_distribution_plots.py b/scripts/hopse_plotting/seed_n_distribution_plots.py new file mode 100644 index 000000000..24a2ff0d5 --- /dev/null +++ b/scripts/hopse_plotting/seed_n_distribution_plots.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +""" +Per-model bar plots of **n_seeds** (raw-run count per hyperparameter group) vs +**n_groups** (how many groups have that count), for every dataset in the report. + +Uses the same pre-filter aggregate as ``aggregator.py`` (``build_seed_bucket_report``): +one subplot per (model, dataset) pair for that model. + +Standalone (same input discovery as ``aggregator.py``):: + + python scripts/hopse_plotting/seed_n_distribution_plots.py + python scripts/hopse_plotting/seed_n_distribution_plots.py -i shards/a.csv shards/b.csv + python scripts/hopse_plotting/seed_n_distribution_plots.py --output-dir plots/custom + +Or enable ``--plot-seed-distributions`` when running ``aggregator.py`` (writes here by default). 
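At its core the report is a two-level count: runs per hyperparameter group, then groups per run count. A rough pandas sketch (the real computation, including which config columns define a group, is ``build_seed_bucket_report`` in ``utils``)::

    import pandas as pd

    def sketch_seed_buckets(runs: pd.DataFrame, config_cols: list[str]) -> pd.DataFrame:
        # n_seeds = raw runs per (model, dataset, config) group;
        # n_groups = how many groups share that n_seeds value.
        per_group = (
            runs.groupby(["model", "dataset", *config_cols], dropna=False)
            .size()
            .rename("n_seeds")
            .reset_index()
        )
        return (
            per_group.groupby(["model", "dataset", "n_seeds"])
            .size()
            .rename("n_groups")
            .reset_index()
        )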
+""" + +from __future__ import annotations + +import argparse +import math +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 +import pandas as pd # noqa: E402 + +from utils import ( + DEFAULT_WANDB_EXPORT_CSV, + DEFAULT_WANDB_EXPORT_SHARD_DIR, + PLOTS_DIR, + aggregate_wandb_export_by_seed, + build_seed_bucket_report, + load_wandb_export_csv, + safe_filename_token, + _union_column_order, +) + +DEFAULT_SEED_N_DIST_DIR = PLOTS_DIR / "seed_n_distributions" + + +def _collect_input_paths( + *, + explicit: list[Path] | None, + input_dir: Path | None, + input_pattern: str, +) -> list[Path]: + paths: list[Path] = [] + if explicit: + paths.extend(explicit) + if input_dir is not None: + d = Path(input_dir) + if d.is_dir(): + paths.extend(sorted(d.glob(input_pattern))) + if not paths: + paths = [DEFAULT_WANDB_EXPORT_CSV] + seen: set[Path] = set() + uniq: list[Path] = [] + for p in paths: + rp = p.resolve() + if rp not in seen: + seen.add(rp) + uniq.append(p) + return uniq + + +def seed_bucket_report_from_export_paths( + paths: list[Path], +) -> pd.DataFrame: + """Match aggregator: aggregate by seed, then ``build_seed_bucket_report`` (no ``--required-seeds`` cut).""" + if not paths: + raise ValueError("no input paths") + if len(paths) == 1: + df = load_wandb_export_csv(paths[0]) + agg = aggregate_wandb_export_by_seed(df) + return build_seed_bucket_report(agg) + + frames: list[pd.DataFrame] = [] + for p in paths: + df = load_wandb_export_csv(p) + frames.append(aggregate_wandb_export_by_seed(df)) + cols = _union_column_order(frames) + out = pd.concat(frames, ignore_index=True, sort=False) + out = out.reindex(columns=cols) + return build_seed_bucket_report(out) + + +def write_seed_distribution_plots( + report: pd.DataFrame, + out_dir: Path, + *, + required_n_seeds: int | None = None, + dpi: int = 150, +) -> int: + """ + Write one PNG per distinct ``model`` in ``report``. + + Each figure: grid of bar charts (x = ``n_seeds``, height = ``n_groups``) for + every ``dataset`` that model appears in. Bars matching ``required_n_seeds`` + (if not None) are highlighted. 
+ """ + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + if report.empty or "model" not in report.columns: + return 0 + + plt.rcParams.update( + { + "font.family": "serif", + "font.serif": ["DejaVu Serif", "Times New Roman", "serif"], + "axes.titlesize": 9, + "axes.labelsize": 8, + "figure.dpi": dpi, + "savefig.dpi": dpi, + } + ) + + n_written = 0 + for model in sorted(report["model"].astype(str).unique()): + sub_m = report[report["model"].astype(str) == model] + datasets = sorted(sub_m["dataset"].astype(str).unique()) + n_ds = len(datasets) + if n_ds == 0: + continue + + n_cols = min(4, n_ds) + n_rows = math.ceil(n_ds / n_cols) + fig_w = max(7.0, n_cols * 3.15) + fig_h = max(3.0, n_rows * 2.75) + fig, axes = plt.subplots( + n_rows, + n_cols, + figsize=(fig_w, fig_h), + squeeze=False, + ) + + for idx, ds in enumerate(datasets): + r, c = divmod(idx, n_cols) + ax = axes[r][c] + piece = sub_m[sub_m["dataset"].astype(str) == ds].copy() + piece = piece.sort_values("n_seeds") + if piece.empty: + ax.text(0.5, 0.5, "no data", ha="center", va="center", transform=ax.transAxes, fontsize=9) + ax.set_title(_dataset_short_title(ds), fontsize=9, fontweight="semibold") + ax.set_axis_off() + continue + piece["n_seeds"] = pd.to_numeric(piece["n_seeds"], errors="coerce") + piece["n_groups"] = pd.to_numeric(piece["n_groups"], errors="coerce").fillna(0) + piece = piece.dropna(subset=["n_seeds"]) + + x = piece["n_seeds"].astype(int).astype(str).tolist() + y = piece["n_groups"].astype(float).tolist() + colors: list[str] = [] + for ns in piece["n_seeds"].astype(int): + if required_n_seeds is not None and int(ns) == int(required_n_seeds): + colors.append("#E74C3C") + else: + colors.append("#4A90A4") + + ax.bar(x, y, color=colors, edgecolor="0.2", linewidth=0.45) + ax.set_title(_dataset_short_title(ds), fontsize=9, fontweight="semibold") + ax.set_xlabel("n_seeds (runs / group)", fontsize=8) + ax.set_ylabel("# groups", fontsize=8) + ax.yaxis.grid(True, linestyle=":", linewidth=0.45, alpha=0.85) + ax.set_axisbelow(True) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + for j in range(n_ds, n_rows * n_cols): + r, c = divmod(j, n_cols) + axes[r][c].set_visible(False) + + fig.suptitle(str(model), fontsize=11, fontweight="bold", y=1.02) + if required_n_seeds is not None: + fig.text( + 0.5, + 0.01, + f"Red bars: n_seeds == {required_n_seeds} (aggregator default filter)", + ha="center", + fontsize=8, + style="italic", + ) + fig.subplots_adjust(bottom=0.12, top=0.92) + else: + fig.subplots_adjust(bottom=0.08, top=0.92) + + stem = safe_filename_token(str(model).replace("/", "__"), max_len=96) + out_path = out_dir / f"{stem}_n_seeds.png" + fig.savefig(out_path, bbox_inches="tight", facecolor="white", edgecolor="none") + plt.close(fig) + n_written += 1 + + return n_written + + +def _dataset_short_title(dataset_path: str) -> str: + t = str(dataset_path).strip() + if "/" in t: + t = t.rsplit("/", 1)[-1] + return t if len(t) <= 36 else t[:33] + "..." + + +def main() -> None: + p = argparse.ArgumentParser( + description="Bar plots of n_seeds distribution per model (all datasets as subplots)." + ) + p.add_argument( + "-i", + "--input", + action="append", + type=Path, + default=None, + metavar="PATH", + help="Per-run export CSV (repeat for shards). 
Default: shard dir or monolithic export.", + ) + p.add_argument( + "--input-dir", + type=Path, + default=None, + help="Glob CSVs under this directory (default: shard folder if present).", + ) + p.add_argument( + "--input-pattern", + default="*.csv", + help="Glob under --input-dir (default: *.csv)", + ) + p.add_argument( + "--output-dir", + type=Path, + default=DEFAULT_SEED_N_DIST_DIR, + help=f"Directory for PNGs (default: {DEFAULT_SEED_N_DIST_DIR})", + ) + p.add_argument( + "--required-seeds", + type=int, + default=5, + metavar="N", + help="Highlight bars where n_seeds equals N (aggregator default); use -1 to disable.", + ) + p.add_argument( + "--dpi", + type=int, + default=150, + help="Figure DPI (default: 150)", + ) + args = p.parse_args() + + input_dir = args.input_dir + if args.input is None and input_dir is None: + sd = DEFAULT_WANDB_EXPORT_SHARD_DIR + if sd.is_dir() and any(sd.glob(args.input_pattern)): + input_dir = sd + + paths = _collect_input_paths( + explicit=args.input, + input_dir=input_dir, + input_pattern=args.input_pattern, + ) + report = seed_bucket_report_from_export_paths(paths) + req = int(args.required_seeds) if int(args.required_seeds) >= 0 else None + n = write_seed_distribution_plots( + report, + args.output_dir, + required_n_seeds=req, + dpi=int(args.dpi), + ) + print(f"Wrote {n} figure(s) -> {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/hopse_plotting/table_generator.py b/scripts/hopse_plotting/table_generator.py new file mode 100644 index 000000000..6efe01d3c --- /dev/null +++ b/scripts/hopse_plotting/table_generator.py @@ -0,0 +1,840 @@ +#!/usr/bin/env python3 +""" +Build a LaTeX table (booktabs / multirow / cell colors) from a **seed-aggregated** +W&B CSV: for each (model, dataset), pick the hyperparameter row with best **validation** +mean — same implementation as ``main_plot`` / ``collapse_aggregated_wandb_csv`` via +``utils.iter_best_val_group_picks`` (default ``group_cols``: ``model``, ``dataset``; +``monitor_column``: ``dataset.parameters.monitor_metric``), then read **test** +``test_best_rerun`` mean ± std from that row. + +The seed-aggregated CSV must include every sweep axis you care about (see +``utils.CONFIG_PARAM_KEYS`` and ``main_loader`` export), or distinct configs can collapse +during aggregation (e.g. missing ``transforms.hopse_encoding.pretrain_model`` for HOPSE_G). + +- **bestgray** + bold: best test value in the column (ties share the style). +- **stdblue**: not significantly different from the column best at 95% confidence + (two-sided Z on independent means in code; SE = seed-agg std / sqrt(n_seeds)). + +Model blocks: graph GCN/GAT/GIN; simplicial HOPSE-M, HOPSE-G, TopoTune, SCCNN +(``simplicial/sccnn_custom``); cell HOPSE-M, HOPSE-G, TopoTune, CWN (``cell/cwn``). +**Dataset columns** come from ``DATASETS`` in ``main_loader.py``, reordered so **all graph +columns precede all simplicial**. By default **four** ``.tex`` files are written under ``tables/``: + +- Base: ``main_table_all.tex`` / ``main_table_no_transductive.tex`` (one row per model). +- **Submodels**: ``main_table_all_submodels.tex`` / ``main_table_no_transductive_submodels.tex`` — + GNN rows split by ``transforms`` (empty → plain name; ``combined_fe`` → ``-F``; ``combined_pe`` → ``-PE``); + HOPSE-M split by ``model.preprocessing_params.encodings`` (**HOPSE-M-F** if HFKE/HKFE appears in the + cell, else **HOPSE-M-PE**); HOPSE-G and TopoTune unchanged. Best validation row is chosen **within** + each sub-row group. 
Use ``--skip-submodel-tables`` to emit only the base pair. + +Usage:: + + python scripts/hopse_plotting/table_generator.py + python scripts/hopse_plotting/table_generator.py -o scripts/hopse_plotting/tables/main_table_all.tex \\ + --output-without-transductive scripts/hopse_plotting/tables/main_table_no_transductive.tex + python scripts/hopse_plotting/table_generator.py --stdout + python scripts/hopse_plotting/table_generator.py --skip-submodel-tables + python scripts/hopse_plotting/table_generator.py --group-by model dataset +""" + +from __future__ import annotations + +import argparse +import math +import sys +from pathlib import Path +from typing import Any, Literal + +import pandas as pd + +from main_loader import DATASETS as LOADER_DATASETS +from utils import ( + DEFAULT_AGGREGATED_EXPORT_CSV, + MONITOR_METRIC_COLUMN, + TABLES_DIR, + _first_existing_column, + _paired_std_from_mean, + _test_mean_columns_for_tail, + hydra_dataset_key_from_loader_identity, + iter_best_val_group_picks, + optimization_mode_for_metric_tail, + load_wandb_export_csv, + safe_filename_token, +) + +DEFAULT_LATEX_TABLE_TEX = TABLES_DIR / "main_table_all.tex" +DEFAULT_LATEX_TABLE_TEX_NO_TRANS = TABLES_DIR / "main_table_no_transductive.tex" +DEFAULT_LATEX_TABLE_TEX_SUBMODELS = TABLES_DIR / "main_table_all_submodels.tex" +DEFAULT_LATEX_TABLE_TEX_NO_TRANS_SUBMODELS = TABLES_DIR / "main_table_no_transductive_submodels.tex" + +COL_TRANSFORMS = "transforms" +COL_PREPROC_ENC = "model.preprocessing_params.encodings" + +GRAPH_MPNN = frozenset({"graph/gcn", "graph/gat", "graph/gin"}) +MODEL_HOPSE_M = frozenset({"simplicial/hopse_m", "cell/hopse_m"}) +MODEL_HOPSE_G_TOPO = frozenset( + { + "simplicial/hopse_g", + "cell/hopse_g", + "simplicial/topotune", + "cell/topotune", + } +) + +# Planetoid cocitation configs (transductive); must match loader ``graph/cocitation_*`` paths. +TRANSDUCTIVE_GRAPH_PATHS: tuple[str, ...] = ( + "graph/cocitation_cora", + "graph/cocitation_citeseer", + "graph/cocitation_pubmed", +) +TRANSDUCTIVE_GRAPH_SET: frozenset[str] = frozenset(TRANSDUCTIVE_GRAPH_PATHS) + +Z_CRIT_95 = 1.959963984540054 + +# W&B often stores 0–1 fractions; publication tables use 0–100 for these tails. 
+_DISPLAY_SCALE_100: frozenset[str] = frozenset( + {"accuracy", "f1", "precision", "recall", "auroc", "roc_auc"} +) + + +def _display_scale(tail: str) -> float: + t = (tail or "").strip().lower() + return 100.0 if t in _DISPLAY_SCALE_100 else 1.0 + + +def _finite(x: Any) -> bool: + try: + v = float(x) + except (TypeError, ValueError): + return False + return math.isfinite(v) + + +def _sem(std: float, n: int) -> float: + if n <= 0 or not _finite(std): + return 0.0 + return float(std) / math.sqrt(float(n)) + + +def _z_two_sample(mu_i: float, se_i: float, mu_j: float, se_j: float) -> float: + v = se_i * se_i + se_j * se_j + if v <= 0.0: + return 0.0 if abs(mu_i - mu_j) < 1e-12 else float("inf") + return abs(mu_i - mu_j) / math.sqrt(v) + + +def _not_sig_diff_from_best(mu: float, se: float, best_mu: float, best_se: float) -> bool: + return _z_two_sample(mu, se, best_mu, best_se) <= Z_CRIT_95 + + +def _parse_dataset_specs(items: list[str]) -> list[tuple[str, str]]: + """``path:LaTeX header`` or bare ``path`` (basename used as header).""" + out: list[tuple[str, str]] = [] + for raw in items: + s = raw.strip() + if not s: + continue + if ":" in s: + path, hdr = s.split(":", 1) + out.append((path.strip(), hdr.strip())) + else: + base = s.rsplit("/", 1)[-1] + out.append((s.strip(), base)) + return out + + +# Short column titles (↑ / ↓ are heuristic for table headers; per-column optimization still comes from data). +_DATASET_COLUMN_LABEL: dict[str, str] = { + "graph/MUTAG": "MUTAG", + "graph/cocitation_cora": "Cora", + "graph/PROTEINS": "PROTEINS", + "graph/NCI1": "NCI1", + "graph/NCI109": "NCI109", + "graph/cocitation_citeseer": "Citeseer", + "graph/cocitation_pubmed": "PubMed", + "simplicial/mantra_name": "NAME", + "simplicial/mantra_orientation": "ORIENT", + "simplicial/mantra_betti_numbers": r"$\beta$", + "graph/BBB_Martins": "BBB", + "graph/CYP3A4_Veith": "CYP3A4", + "graph/Clearance_Hepatocyte_AZ": "Cl.Hep.", + "graph/Caco2_Wang": "Caco2", +} + +_DATASET_MIN_ARROW: frozenset[str] = frozenset( + { + "graph/Clearance_Hepatocyte_AZ", + "graph/Caco2_Wang", + "simplicial/mantra_betti_numbers", + } +) + + +def _latex_short_dataset_label(path: str) -> str: + return _DATASET_COLUMN_LABEL.get(path, path.rsplit("/", 1)[-1].replace("_", r"\_")) + + +def _auto_header_for_dataset_path(path: str) -> str: + short = _latex_short_dataset_label(path) + arr = r"$\downarrow$" if path in _DATASET_MIN_ARROW else r"$\uparrow$" + return f"{short} ({arr})" + + +def _specs_from_loader_paths() -> list[tuple[str, str]]: + return [ + (p.strip(), _auto_header_for_dataset_path(p.strip())) + for p in LOADER_DATASETS + if p.strip() + ] + + +def partition_specs_three_way( + specs: list[tuple[str, str]], +) -> list[tuple[str, list[tuple[str, str]]]]: + """ + Graph (transductive) → Graph (inductive) → Simplicial (inductive). + Transductive columns are only Cora / Citeseer / PubMed (``cocitation_*``), in that order. 
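For instance, specs covering cocitation Cora, MUTAG, and MANTRA NAME (headers as ``_auto_header_for_dataset_path`` would build them) come back as::

    [("Graph (transductive)", [("graph/cocitation_cora", "Cora ($\uparrow$)")]),
     ("Graph (inductive)", [("graph/MUTAG", "MUTAG ($\uparrow$)")]),
     ("Simplicial (inductive)", [("simplicial/mantra_name", "NAME ($\uparrow$)")])]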
+ """ + by_path = {p: h for p, h in specs} + trans = [(p, by_path[p]) for p in TRANSDUCTIVE_GRAPH_PATHS if p in by_path] + graph_ind: list[tuple[str, str]] = [] + simplicial: list[tuple[str, str]] = [] + for p, h in specs: + if p in TRANSDUCTIVE_GRAPH_SET: + continue + if p.startswith("graph/"): + graph_ind.append((p, h)) + elif p.startswith("simplicial/"): + simplicial.append((p, h)) + blocks = [ + ("Graph (transductive)", trans), + ("Graph (inductive)", graph_ind), + ("Simplicial (inductive)", simplicial), + ] + return [(title, blk) for title, blk in blocks if blk] + + +def partition_specs_two_way_no_transductive( + specs: list[tuple[str, str]], +) -> list[tuple[str, list[tuple[str, str]]]]: + """Graph then Simplicial; omits cocitation Cora/Citeseer/PubMed. Headers omit ``(inductive)``.""" + graph_ind: list[tuple[str, str]] = [] + simplicial: list[tuple[str, str]] = [] + for p, h in specs: + if p in TRANSDUCTIVE_GRAPH_SET: + continue + if p.startswith("graph/"): + graph_ind.append((p, h)) + elif p.startswith("simplicial/"): + simplicial.append((p, h)) + blocks = [ + ("Graph", graph_ind), + ("Simplicial", simplicial), + ] + return [(title, blk) for title, blk in blocks if blk] + + +def _is_empty_transforms(val: Any) -> bool: + if val is None: + return True + try: + if pd.isna(val): + return True + except (TypeError, ValueError): + pass + s = str(val).replace("\r", "").strip().lower() + return s in ("", "nan", "none", "[]", "{}", "null") + + +def _graph_transforms_sub_id(val: Any) -> str: + """Bucket GNN rows by ``transforms`` for separate best-val picks.""" + if _is_empty_transforms(val): + return "base" + s = str(val).replace("\r", "").strip().lower() + if s == "combined_fe": + return "fe" + if s == "combined_pe": + return "pe" + tok = safe_filename_token(s, max_len=48) + return f"other::{tok}" + + +def _hopse_m_enc_sub_id(val: Any) -> str: + """ + HOPSE-M: HFKE (or HKFE as stored in exports) in ``model.preprocessing_params.encodings`` + → ``f`` (display HOPSE-M-F), else ``pe`` (HOPSE-M-PE). 
+ """ + s = str(val if val is not None else "").replace("\r", "") + su = s.upper() + if "HFKE" in su or "HKFE" in su: + return "f" + return "pe" + + +def _assign_sub_id_for_row(model: str, row: pd.Series) -> str: + m = str(model).strip() + if m in GRAPH_MPNN: + tv = row[COL_TRANSFORMS] if COL_TRANSFORMS in row.index else None + return _graph_transforms_sub_id(tv) + if m in MODEL_HOPSE_M: + ev = row[COL_PREPROC_ENC] if COL_PREPROC_ENC in row.index else None + return _hopse_m_enc_sub_id(ev) + if m in MODEL_HOPSE_G_TOPO: + return "default" + return "default" + + +def dataframe_with_submodel_id(df: pd.DataFrame) -> pd.DataFrame: + out = df.copy() + subs: list[str] = [] + for _idx, row in out.iterrows(): + subs.append(_assign_sub_id_for_row(str(row.get("model", "")), row)) + out["_sub_id"] = subs + return out + + +def _sort_graph_sub_ids(subs: set[str]) -> list[str]: + def sk(s: str) -> tuple[int, str]: + if s == "base": + return (0, s) + if s == "fe": + return (1, s) + if s == "pe": + return (2, s) + return (3, s) + + return sorted(subs, key=sk) + + +def _latex_graph_sub_row_label(base_short: str, sub_id: str) -> str: + if sub_id == "base": + return base_short + if sub_id == "fe": + return f"{base_short}-F" + if sub_id == "pe": + return f"{base_short}-PE" + if sub_id.startswith("other::"): + body = sub_id.split("other::", 1)[1].replace("_", r"\_") + return f"{base_short}-\\texttt{{{body}}}" + body = sub_id.replace("_", r"\_") + return f"{base_short}-\\texttt{{{body}}}" + + +def graph_submodel_table_rows(stats: dict[tuple[str, str], Any]) -> list[tuple[str, str]]: + """(stats_row_key, LaTeX label) for GCN/GAT/GIN sub-rows.""" + seen_keys = {k[0] for k in stats} + templates = [("graph/gcn", "GCN"), ("graph/gat", "GAT"), ("graph/gin", "GIN")] + rows: list[tuple[str, str]] = [] + for mid, lab in templates: + subs = {rk.split("|", 1)[-1] for rk in seen_keys if rk.startswith(mid + "|")} + if not subs: + subs = {"base"} + for sub in _sort_graph_sub_ids(subs): + rows.append((f"{mid}|{sub}", _latex_graph_sub_row_label(lab, sub))) + return rows + + +def simplicial_submodel_table_rows() -> list[tuple[str, str]]: + return [ + (r"simplicial/hopse_m|f", r"\textbf{HOPSE-M-F} (Our)"), + (r"simplicial/hopse_m|pe", r"\textbf{HOPSE-M-PE} (Our)"), + (r"simplicial/hopse_g|default", r"\textbf{HOPSE-G} (Our)"), + (r"simplicial/topotune|default", "TopoTune"), + (r"simplicial/sccnn_custom|default", "SCCNN"), + ] + + +def cell_submodel_table_rows() -> list[tuple[str, str]]: + return [ + (r"cell/hopse_m|f", r"\textbf{HOPSE-M-F} (Our)"), + (r"cell/hopse_m|pe", r"\textbf{HOPSE-M-PE} (Our)"), + (r"cell/hopse_g|default", r"\textbf{HOPSE-G} (Our)"), + (r"cell/topotune|default", "TopoTune"), + (r"cell/cwn|default", "CWN"), + ] + + +def collect_winner_test_by_model_dataset( + df: pd.DataFrame, + *, + group_cols: tuple[str, ...] = ("model", "dataset"), +) -> dict[tuple[str, str], dict[str, Any]]: + """ + (model, dataset_canon) -> {test_mean, test_std, n_seeds, tail, mode, monitor_raw}. + + ``group_cols`` must include ``model`` and ``dataset`` (same contract as + ``main_plot --group-by`` / ``collapse_aggregated_wandb_by_best_val``). 
+ """ + if "model" not in group_cols or "dataset" not in group_cols: + raise ValueError("collect_winner_test_by_model_dataset: group_cols must include 'model' and 'dataset'") + colset = set(df.columns) + out: dict[tuple[str, str], dict[str, Any]] = {} + for keys, pick_idx, monitor_val, tail in iter_best_val_group_picks( + df, group_cols=list(group_cols), monitor_column=MONITOR_METRIC_COLUMN + ): + gk = keys if isinstance(keys, tuple) else (keys,) + if len(gk) != len(group_cols): + raise RuntimeError("groupby key length mismatch vs group_cols") + zd = dict(zip(group_cols, gk, strict=True)) + model = str(zd["model"]).strip() + dataset_raw = str(zd["dataset"]).strip() + dataset = hydra_dataset_key_from_loader_identity(dataset_raw) + w = df.loc[pick_idx] + mode: Literal["max", "min"] = optimization_mode_for_metric_tail(tail) if tail else "max" + test_src = _first_existing_column(_test_mean_columns_for_tail(tail), colset) + te_std = _paired_std_from_mean(test_src, colset) if test_src else None + mu = pd.to_numeric(w.get(test_src), errors="coerce") if test_src else float("nan") + sd = pd.to_numeric(w.get(te_std), errors="coerce") if te_std else float("nan") + n_raw = w.get("n_seeds", float("nan")) + n = int(pd.to_numeric(n_raw, errors="coerce")) if _finite(n_raw) else 0 + out[(model, dataset)] = { + "test_mean": float(mu) if pd.notna(mu) else float("nan"), + "test_std": float(sd) if pd.notna(sd) else float("nan"), + "n_seeds": max(n, 0), + "tail": tail, + "mode": mode, + "monitor_raw": str(monitor_val).strip(), + } + return out + + +def collect_winner_test_by_submodel(df: pd.DataFrame) -> dict[tuple[str, str], dict[str, Any]]: + """ + Like ``collect_winner_test_by_model_dataset`` but groups by (model, dataset, _sub_id). + + Row keys in the returned map are ``f"{model}|{sub_id}"`` where ``sub_id`` comes from + ``transforms`` (GNN) or ``model.preprocessing_params.encodings`` (HOPSE-M), or + ``default`` for HOPSE-G / TopoTune. + """ + work = dataframe_with_submodel_id(df) + colset = set(work.columns) + out: dict[tuple[str, str], dict[str, Any]] = {} + gc = ["model", "dataset", "_sub_id"] + for keys, pick_idx, monitor_val, tail in iter_best_val_group_picks( + work, group_cols=gc, monitor_column=MONITOR_METRIC_COLUMN + ): + model = str(keys[0]).strip() + dataset_raw = str(keys[1]).strip() + sub_id = str(keys[2]).strip() + dataset = hydra_dataset_key_from_loader_identity(dataset_raw) + row_key = f"{model}|{sub_id}" + w = work.loc[pick_idx] + mode: Literal["max", "min"] = optimization_mode_for_metric_tail(tail) if tail else "max" + test_src = _first_existing_column(_test_mean_columns_for_tail(tail), colset) + te_std = _paired_std_from_mean(test_src, colset) if test_src else None + mu = pd.to_numeric(w.get(test_src), errors="coerce") if test_src else float("nan") + sd = pd.to_numeric(w.get(te_std), errors="coerce") if te_std else float("nan") + n_raw = w.get("n_seeds", float("nan")) + n = int(pd.to_numeric(n_raw, errors="coerce")) if _finite(n_raw) else 0 + out[(row_key, dataset)] = { + "test_mean": float(mu) if pd.notna(mu) else float("nan"), + "test_std": float(sd) if pd.notna(sd) else float("nan"), + "n_seeds": max(n, 0), + "tail": tail, + "mode": mode, + "monitor_raw": str(monitor_val).strip(), + } + return out + + +def _fmt_cell(mu: float, sd: float, *, decimals: int = 2, scale: float = 1.0) -> str: + if not _finite(mu): + return "-" + mu *= scale + sd = sd * scale if _finite(sd) else float("nan") + # \pm must be in math mode (text-mode triggers "Missing $ inserted"). 
+ if _finite(sd): + return f"${mu:.{decimals}f} \\pm {sd:.{decimals}f}$" + return f"${mu:.{decimals}f}$" + + +def _latex_cell_body( + mu: float, + sd: float, + se: float, + *, + is_best: bool, + blue_tie: bool, + decimals: int, + scale: float, +) -> str: + body = _fmt_cell(mu, sd, decimals=decimals, scale=scale) + if body == "-": + return "-" + inner = f"{{\\scriptsize {body}}}" + if is_best: + # \textbf does not bold math digits; \boldmath applies to the following math. + return f"{{\\cellcolor{{bestgray}}{{\\scriptsize\\boldmath {body}}}}}" + if blue_tie: + return f"\\cellcolor{{stdblue}}{inner}" + return inner + + +def build_latex_table( + stats: dict[tuple[str, str], dict[str, Any]], + *, + column_groups: list[tuple[str, list[tuple[str, str]]]], + graph_rows: list[tuple[str, str]], + simplicial_rows: list[tuple[str, str]], + cell_rows: list[tuple[str, str]], + decimals: int = 2, + scale_fraction_metrics: bool = True, + label: str = "tbl:hopse_wandb_graph_trans_ind_sim", +) -> str: + """ + Return full LaTeX fragment (table env + suggested \\definecolor comments). + + Each of ``graph_rows`` / ``simplicial_rows`` / ``cell_rows`` is + ``(stats_row_key, latex_model_label)``. Base tables use ``stats_row_key == model`` + (e.g. ``graph/gcn``); submodel tables use keys like ``graph/gcn|fe``. + """ + dataset_specs: list[tuple[str, str]] = [] + group_ranges: list[tuple[str, int, int]] = [] + for title, block in column_groups: + if not block: + continue + i0 = len(dataset_specs) + dataset_specs.extend(block) + group_ranges.append((title, i0, len(dataset_specs) - 1)) + + n_d = len(dataset_specs) + colspec = "@{}ll" + "c" * n_d + "@{}" + + all_row_keys = [rk for rk, _ in graph_rows + simplicial_rows + cell_rows] + + def cell_colored(row_key: str, ds_path: str) -> str: + ds_key = hydra_dataset_key_from_loader_identity(ds_path) + st = stats.get((row_key, ds_key)) + if not st or not _finite(st.get("test_mean", float("nan"))): + return "-" + mu = float(st["test_mean"]) + sd = float(st["test_std"]) if _finite(st.get("test_std")) else float("nan") + se = _sem(sd, int(st.get("n_seeds", 0))) + mode = st.get("mode", "max") + tail = str(st.get("tail", "")) + dsc = _display_scale(tail) if scale_fraction_metrics else 1.0 + + mus: list[float] = [] + for rk in all_row_keys: + t = stats.get((rk, ds_key)) + if t and _finite(t.get("test_mean")): + mus.append(float(t["test_mean"])) + if not mus: + return _latex_cell_body( + mu, sd, se, is_best=False, blue_tie=False, decimals=decimals, scale=dsc + ) + + best_val = max(mus) if mode == "max" else min(mus) + is_best = abs(mu - best_val) <= 1e-9 * (1 + abs(best_val)) + + ref_mu, ref_se = best_val, 0.0 + for rk in all_row_keys: + t = stats.get((rk, ds_key)) + if not t or not _finite(t.get("test_mean")): + continue + if abs(float(t["test_mean"]) - best_val) > 1e-9 * (1 + abs(best_val)): + continue + ref_mu = float(t["test_mean"]) + ref_se = _sem( + float(t["test_std"]) if _finite(t.get("test_std")) else 0.0, + int(t.get("n_seeds", 0)), + ) + break + + blue = not is_best and _not_sig_diff_from_best(mu, se, ref_mu, ref_se) + return _latex_cell_body( + mu, sd, se, is_best=is_best, blue_tie=blue, decimals=decimals, scale=dsc + ) + + lines: list[str] = [] + lines.append("% --- Requires: \\usepackage{booktabs,multirow,adjustbox,graphicx,xcolor,colortbl}") + lines.append( + "\\definecolor{stdblue}{HTML}{C9DAF8}% same swatch as non-significant cells (tweak to match venue)" + ) + lines.append("\\definecolor{bestgray}{HTML}{D9D9D9}") + lines.append("\\begin{table}[t]") + cap = ( + 
"Test mean $\\pm$ std over seeds (hyperparameters chosen on validation). " + "Best mean result per dataset highlighted in \\textbf{Bold}. " + "Results in \\protect\\colorbox{stdblue}{blue} are not significantly different " + "from best model (95\\,\\% confidence)." + ) + lines.append(f"\\caption{{{cap}}}") + lines.append(f"\\label{{{label}}}") + lines.append("\\centering") + lines.append("\\begin{adjustbox}{width=1.\\textwidth}") + # Must be *outside* tabular: a \\renewcommand right after \\begin{tabular} can break the + # alignment (Misplaced \\cr / \\noalign) when the next row uses \\multicolumn + \\cmidrule. + lines.append("\\renewcommand{\\arraystretch}{1.4}") + lines.append(f"\\begin{{tabular}}{{{colspec}}}") + lines.append("\\toprule") + + if n_d > 0 and group_ranges: + multicols = [] + cmid_parts = [] + for title, i0, i1 in group_ranges: + span = i1 - i0 + 1 + # \\mbox isolates parentheses from babel / chemistry packages that treat "(" specially. + multicols.append(f"\\multicolumn{{{span}}}{{c}}{{\\mbox{{{title}}}}}") + cmid_parts.append(f"\\cmidrule(lr){{{3 + i0}-{3 + i1}}}") + lines.append(" & & " + " & ".join(multicols) + " \\\\") + lines.append(" ".join(cmid_parts)) + + hdr = " & \\textbf{Model}" + for _p, h in dataset_specs: + hdr += f" & \\scriptsize {h}" + hdr += " \\\\" + lines.append(hdr) + lines.append("\\midrule") + + def emit_model_block(rotate: str, rows: list[tuple[str, str]]) -> None: + n_r = len(rows) + rk0, lab0 = rows[0] + row = ( + f"\\multirow{{{n_r}}}{{*}}{{\\rotatebox[origin=c]{{90}}{{\\textbf{{{rotate}}}}}}} " + f"& {lab0}" + ) + for ds_path, _h in dataset_specs: + row += " & " + cell_colored(rk0, ds_path) + lines.append(row + " \\\\") + for rk, lab in rows[1:]: + row = f"& {lab}" + for ds_path, _h in dataset_specs: + row += " & " + cell_colored(rk, ds_path) + lines.append(row + " \\\\") + + emit_model_block("Graph", graph_rows) + lines.append("\\midrule") + emit_model_block("Simplicial", simplicial_rows) + lines.append("\\midrule") + emit_model_block("Cell", cell_rows) + + lines.append("\\bottomrule") + lines.append("\\end{tabular}") + lines.append("\\end{adjustbox}") + lines.append("\\end{table}") + return "\n".join(lines) + "\n" + + +def main() -> None: + p = argparse.ArgumentParser(description="Emit LaTeX leaderboard table from seed-aggregated W&B CSV.") + p.add_argument( + "-i", + "--input", + type=Path, + default=DEFAULT_AGGREGATED_EXPORT_CSV, + help=f"Seed-aggregated CSV (default: {DEFAULT_AGGREGATED_EXPORT_CSV})", + ) + p.add_argument( + "-o", + "--output", + type=Path, + default=DEFAULT_LATEX_TABLE_TEX, + help=( + "Output .tex for the three-band table: Graph (transductive), Graph (inductive), " + f"Simplicial (inductive) (default: {DEFAULT_LATEX_TABLE_TEX})" + ), + ) + p.add_argument( + "--output-without-transductive", + type=Path, + default=DEFAULT_LATEX_TABLE_TEX_NO_TRANS, + help=( + "Second .tex: no cocitation Cora/Citeseer/PubMed; band titles Graph / Simplicial " + f"(no '(inductive)') (default: {DEFAULT_LATEX_TABLE_TEX_NO_TRANS})" + ), + ) + p.add_argument( + "--stdout", + action="store_true", + help=( + "Print LaTeX to stdout (three-band first, then a comment separator, then two-band " + "if that version has at least one column)" + ), + ) + p.add_argument( + "--datasets", + nargs="*", + default=None, + metavar="PATH:HEADER", + help=( + "Dataset columns as path or path:LaTeX header. " + "Default: DATASETS from main_loader.py, reordered to " + "transductive graph (cocitation cora/citeseer/pubmed) → other graph → simplicial." 
+ ), + ) + p.add_argument("--decimals", type=int, default=2, help="Decimal places for numbers (default: 2)") + p.add_argument( + "--no-scale-fractions", + action="store_true", + help="Do not multiply accuracy/f1/... by 100 for display (W&B is often 0–1).", + ) + p.add_argument( + "--skip-submodel-tables", + action="store_true", + help="Do not emit submodel-split tables (GNN by transforms; HOPSE-M-F / HOPSE-M-PE).", + ) + p.add_argument( + "-o-sub", + "--output-submodels", + type=Path, + default=DEFAULT_LATEX_TABLE_TEX_SUBMODELS, + help=f"Submodel three-band .tex (default: {DEFAULT_LATEX_TABLE_TEX_SUBMODELS})", + ) + p.add_argument( + "--output-without-transductive-submodels", + type=Path, + default=DEFAULT_LATEX_TABLE_TEX_NO_TRANS_SUBMODELS, + help=( + "Submodel two-band .tex (no cocitation trio) " + f"(default: {DEFAULT_LATEX_TABLE_TEX_NO_TRANS_SUBMODELS})" + ), + ) + p.add_argument( + "--group-by", + metavar="COL", + nargs="+", + default=["model", "dataset"], + help=( + "Columns for best-val hyperparameter pick (default: model dataset). " + "Must include both model and dataset; same meaning as ``main_plot --group-by``." + ), + ) + args = p.parse_args() + + group_cols = tuple(args.group_by) + df = load_wandb_export_csv(args.input) + stats = collect_winner_test_by_model_dataset(df, group_cols=group_cols) + stats_sub = ( + collect_winner_test_by_submodel(df) if not args.skip_submodel_tables else {} + ) + + if args.datasets: + base_specs = _parse_dataset_specs(args.datasets) + else: + base_specs = _specs_from_loader_paths() + + groups_three = partition_specs_three_way(base_specs) + groups_two = partition_specs_two_way_no_transductive(base_specs) + n_two_cols = sum(len(b) for _, b in groups_two) + + graph_rows_base: list[tuple[str, str]] = [ + ("graph/gcn", "GCN"), + ("graph/gat", "GAT"), + ("graph/gin", "GIN"), + ] + simplicial_rows_base: list[tuple[str, str]] = [ + ("simplicial/hopse_m", "\\textbf{HOPSE-M} (Our)"), + ("simplicial/hopse_g", "\\textbf{HOPSE-G} (Our)"), + ("simplicial/topotune", "TopoTune"), + ("simplicial/sccnn_custom", "SCCNN"), + ] + cell_rows_base: list[tuple[str, str]] = [ + ("cell/hopse_m", "\\textbf{HOPSE-M} (Our)"), + ("cell/hopse_g", "\\textbf{HOPSE-G} (Our)"), + ("cell/topotune", "TopoTune"), + ("cell/cwn", "CWN"), + ] + + graph_rows_sub = graph_submodel_table_rows(stats_sub) + simplicial_rows_sub = simplicial_submodel_table_rows() + cell_rows_sub = cell_submodel_table_rows() + + tex_three = build_latex_table( + stats, + column_groups=groups_three, + graph_rows=graph_rows_base, + simplicial_rows=simplicial_rows_base, + cell_rows=cell_rows_base, + decimals=args.decimals, + scale_fraction_metrics=not args.no_scale_fractions, + label="tbl:hopse_wandb_graph_trans_ind_sim", + ) + tex_two: str | None = None + if n_two_cols > 0: + tex_two = build_latex_table( + stats, + column_groups=groups_two, + graph_rows=graph_rows_base, + simplicial_rows=simplicial_rows_base, + cell_rows=cell_rows_base, + decimals=args.decimals, + scale_fraction_metrics=not args.no_scale_fractions, + label="tbl:hopse_wandb_graph_ind_sim", + ) + + tex_three_sub: str | None = None + tex_two_sub: str | None = None + if not args.skip_submodel_tables: + tex_three_sub = build_latex_table( + stats_sub, + column_groups=groups_three, + graph_rows=graph_rows_sub, + simplicial_rows=simplicial_rows_sub, + cell_rows=cell_rows_sub, + decimals=args.decimals, + scale_fraction_metrics=not args.no_scale_fractions, + label="tbl:hopse_wandb_graph_trans_ind_sim_sub", + ) + if n_two_cols > 0: + tex_two_sub = 
build_latex_table( + stats_sub, + column_groups=groups_two, + graph_rows=graph_rows_sub, + simplicial_rows=simplicial_rows_sub, + cell_rows=cell_rows_sub, + decimals=args.decimals, + scale_fraction_metrics=not args.no_scale_fractions, + label="tbl:hopse_wandb_graph_ind_sim_sub", + ) + + if args.stdout: + sys.stdout.write(tex_three) + if tex_two is not None: + sys.stdout.write( + "\n% --- version without transductive graph (cocitation cora/citeseer/pubmed) ---\n\n" + ) + sys.stdout.write(tex_two) + if tex_three_sub is not None: + sys.stdout.write( + "\n% --- submodels: GNN by transforms; HOPSE-M-F / HOPSE-M-PE by encodings ---\n\n" + ) + sys.stdout.write(tex_three_sub) + if tex_two_sub is not None: + sys.stdout.write( + "\n% --- submodels without transductive graph columns ---\n\n" + ) + sys.stdout.write(tex_two_sub) + else: + out1 = Path(args.output) + out1.parent.mkdir(parents=True, exist_ok=True) + out1.write_text(tex_three, encoding="utf-8") + print(f"Wrote {out1}") + if tex_two is not None: + out2 = Path(args.output_without_transductive) + out2.parent.mkdir(parents=True, exist_ok=True) + out2.write_text(tex_two, encoding="utf-8") + print(f"Wrote {out2}") + else: + print( + "Skipped second table (--output-without-transductive): " + "no columns left after dropping transductive graph datasets." + ) + if tex_three_sub is not None: + out_s1 = Path(args.output_submodels) + out_s1.parent.mkdir(parents=True, exist_ok=True) + out_s1.write_text(tex_three_sub, encoding="utf-8") + print(f"Wrote {out_s1}") + if tex_two_sub is not None: + out_s2 = Path(args.output_without_transductive_submodels) + out_s2.parent.mkdir(parents=True, exist_ok=True) + out_s2.write_text(tex_two_sub, encoding="utf-8") + print(f"Wrote {out_s2}") + elif not args.skip_submodel_tables and n_two_cols == 0: + print( + "Skipped submodel two-band table: no columns left after dropping transductive graph datasets." + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/hopse_plotting/timing_table_generator.py b/scripts/hopse_plotting/timing_table_generator.py new file mode 100644 index 000000000..8b83839ce --- /dev/null +++ b/scripts/hopse_plotting/timing_table_generator.py @@ -0,0 +1,647 @@ +#!/usr/bin/env python3 +""" +Build a LaTeX timing table from the W&B ``best_runs_rerun`` project. + +For each (model, dataset) pair, assumes one rerun (dedupes if multiple). Reads +``AvgTime/train_epoch_mean`` and ``AvgTime/train_epoch_std`` from each run's +summary. Emits a LaTeX table with: + +- **bestgray** + bold: lowest mean time per epoch in the column (ties share style). +- **stdblue**: not significantly different from the column best at 95% confidence + (two-sided Z on independent means; SE = std / sqrt(n_seeds), assumed n_seeds=10). + +Model blocks: Graph (GCN/GAT/GIN), Simplicial (HOPSE-M/HOPSE-G/TopoTune), +Cell (same trio). Dataset columns use the same ordering as ``table_generator.py``. + +Usage:: + + python scripts/hopse_plotting/timing_table_generator.py + python scripts/hopse_plotting/timing_table_generator.py --entity your-entity + python scripts/hopse_plotting/timing_table_generator.py -o tables/timing_table.tex + python scripts/hopse_plotting/timing_table_generator.py --stdout +""" + +from __future__ import annotations + +import argparse +import math +import os +import sys +from collections import defaultdict +from pathlib import Path +from typing import Any + +try: + import wandb +except ImportError: + print("Error: wandb package required. 
Install with: pip install wandb", file=sys.stderr) + sys.exit(1) + +try: + import pandas as pd +except ImportError: + print("Error: pandas package required. Install with: pip install pandas", file=sys.stderr) + sys.exit(1) + +from main_loader import DATASETS as LOADER_DATASETS +from utils import ( + TABLES_DIR, + flatten_config, + run_with_wandb_retry, + safe_filename_token, +) + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + +WANDB_ENTITY_DEFAULT = "gbg141-hopse" +WANDB_PROJECT_RERUNS = "best_runs_rerun" + +DEFAULT_TIMING_TABLE_TEX = TABLES_DIR / "timing_table_all.tex" +DEFAULT_TIMING_TABLE_TEX_NO_TRANS = TABLES_DIR / "timing_table_no_transductive.tex" + +# Same as table_generator.py +TRANSDUCTIVE_GRAPH_PATHS: tuple[str, ...] = ( + "graph/cocitation_cora", + "graph/cocitation_citeseer", + "graph/cocitation_pubmed", +) +TRANSDUCTIVE_GRAPH_SET: frozenset[str] = frozenset(TRANSDUCTIVE_GRAPH_PATHS) + +Z_CRIT_95 = 1.959963984540054 + +# Assumed number of seeds for SE calculation (std / sqrt(n_seeds)) +N_SEEDS_ASSUMED = 10 + +# ----------------------------------------------------------------------------- +# Dataset column headers (same as table_generator.py) +# ----------------------------------------------------------------------------- + +_DATASET_COLUMN_LABEL: dict[str, str] = { + "graph/MUTAG": "MUTAG", + "graph/cocitation_cora": "Cora", + "graph/PROTEINS": "PROTEINS", + "graph/NCI1": "NCI1", + "graph/NCI109": "NCI109", + "graph/cocitation_citeseer": "Citeseer", + "graph/cocitation_pubmed": "PubMed", + "simplicial/mantra_name": "NAME", + "simplicial/mantra_orientation": "ORIENT", + "simplicial/mantra_betti_numbers": r"$\beta$", + "graph/BBB_Martins": "BBB", + "graph/CYP3A4_Veith": "CYP3A4", + "graph/Clearance_Hepatocyte_AZ": "Cl.Hep.", + "graph/Caco2_Wang": "Caco2", +} + + +def _latex_short_dataset_label(path: str) -> str: + return _DATASET_COLUMN_LABEL.get(path, path.rsplit("/", 1)[-1].replace("_", r"\_")) + + +def _auto_header_for_dataset_path(path: str) -> str: + short = _latex_short_dataset_label(path) + # For timing, lower is always better + return f"{short} " + r"($\downarrow$)" + + +def _specs_from_loader_paths() -> list[tuple[str, str]]: + return [ + (p.strip(), _auto_header_for_dataset_path(p.strip())) + for p in LOADER_DATASETS + if p.strip() + ] + + +def partition_specs_three_way( + specs: list[tuple[str, str]], +) -> list[tuple[str, list[tuple[str, str]]]]: + """ + Graph (transductive) → Graph (inductive) → Simplicial (inductive). 
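+
+    Illustrative sketch (example inputs only, not the full dataset list)::
+
+        specs = [("graph/MUTAG", "MUTAG"),
+                 ("graph/cocitation_cora", "Cora"),
+                 ("simplicial/mantra_name", "NAME")]
+        partition_specs_three_way(specs)
+        # -> [("Graph (transductive)",   [("graph/cocitation_cora", "Cora")]),
+        #     ("Graph (inductive)",      [("graph/MUTAG", "MUTAG")]),
+        #     ("Simplicial (inductive)", [("simplicial/mantra_name", "NAME")])]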
+ """ + by_path = {p: h for p, h in specs} + trans = [(p, by_path[p]) for p in TRANSDUCTIVE_GRAPH_PATHS if p in by_path] + graph_ind: list[tuple[str, str]] = [] + simplicial: list[tuple[str, str]] = [] + for p, h in specs: + if p in TRANSDUCTIVE_GRAPH_SET: + continue + if p.startswith("graph/"): + graph_ind.append((p, h)) + elif p.startswith("simplicial/"): + simplicial.append((p, h)) + blocks = [ + ("Graph (transductive)", trans), + ("Graph (inductive)", graph_ind), + ("Simplicial (inductive)", simplicial), + ] + return [(title, blk) for title, blk in blocks if blk] + + +def partition_specs_two_way_no_transductive( + specs: list[tuple[str, str]], +) -> list[tuple[str, list[tuple[str, str]]]]: + """Graph then Simplicial; omits cocitation Cora/Citeseer/PubMed.""" + graph_ind: list[tuple[str, str]] = [] + simplicial: list[tuple[str, str]] = [] + for p, h in specs: + if p in TRANSDUCTIVE_GRAPH_SET: + continue + if p.startswith("graph/"): + graph_ind.append((p, h)) + elif p.startswith("simplicial/"): + simplicial.append((p, h)) + blocks = [ + ("Graph", graph_ind), + ("Simplicial", simplicial), + ] + return [(title, blk) for title, blk in blocks if blk] + + +# ----------------------------------------------------------------------------- +# Statistical helpers +# ----------------------------------------------------------------------------- + + +def _finite(x: Any) -> bool: + try: + v = float(x) + except (TypeError, ValueError): + return False + return math.isfinite(v) + + +def _sem(std: float, n: int) -> float: + if n <= 0 or not _finite(std): + return 0.0 + return float(std) / math.sqrt(float(n)) + + +def _z_two_sample(mu_i: float, se_i: float, mu_j: float, se_j: float) -> float: + v = se_i * se_i + se_j * se_j + if v <= 0.0: + return 0.0 if abs(mu_i - mu_j) < 1e-12 else float("inf") + return abs(mu_i - mu_j) / math.sqrt(v) + + +def _not_sig_diff_from_best(mu: float, se: float, best_mu: float, best_se: float) -> bool: + return _z_two_sample(mu, se, best_mu, best_se) <= Z_CRIT_95 + + +# ----------------------------------------------------------------------------- +# W&B data collection +# ----------------------------------------------------------------------------- + + +def collect_timing_data( + entity: str, + project: str, + verbose: bool = True, +) -> dict[tuple[str, str], tuple[float, float]]: + """ + Fetch timing data from W&B. 
+ + Returns: + dict[(model, dataset)] = (mean_time, std_time) + """ + api = wandb.Api(timeout=60) + + def _fetch(): + return api.runs( + f"{entity}/{project}", + filters={"state": "finished"}, + ) + + if verbose: + print(f"Fetching runs from {entity}/{project} ...") + + runs = run_with_wandb_retry(_fetch, label=f"{entity}/{project}") + + # Group by (model, dataset) + runs_by_key: dict[tuple[str, str], list[Any]] = defaultdict(list) + + for run in runs: + # Parse model and dataset from run name: domain__model__domain__dataset + # e.g., "graph__gat__graph__BBB_Martins" -> model="graph/gat", dataset="graph/BBB_Martins" + run_name = run.name + parts = run_name.split("__") + + if len(parts) < 4: + if verbose: + print(f" (skip) run {run.name}: unexpected name format (expected domain__model__domain__dataset)") + continue + + # Reconstruct with slashes: domain/model and domain/dataset + model_domain = parts[0] + model_name = parts[1] + dataset_domain = parts[2] + dataset_name = "__".join(parts[3:]) # In case dataset has __ in it + + model = f"{model_domain}/{model_name}" + dataset = f"{dataset_domain}/{dataset_name}" + + runs_by_key[(model, dataset)].append(run) + + # Extract timing data + timing_data: dict[tuple[str, str], tuple[float, float]] = {} + + # Track what keys we've seen for diagnostics + timing_keys_seen: set[str] = set() + first_run_debugged = False + + for (model, dataset), run_list in runs_by_key.items(): + if len(run_list) > 1: + if verbose: + print( + f" (warn) Multiple runs for ({model}, {dataset}): " + f"{len(run_list)} runs. Using first." + ) + + run = run_list[0] + + # AvgTime/* is logged as a metric (W&B summary). Rerun projects often have an + # empty run.config; utils.run_to_row reads metrics via dict(run.summary). + summary = dict(run.summary) + flat_config = flatten_config(dict(run.config)) + metrics: dict[str, Any] = {**flat_config, **summary} + + if not first_run_debugged and verbose: + sk = sorted(summary.keys()) + timing_in_summary = [ + k for k in sk if "time" in k.lower() or "Time" in k or "AvgTime" in k + ] + print( + f"\nDEBUG: first run '{run.name}': " + f"{len(summary)} summary keys, {len(flat_config)} flattened config keys." 
+ ) + if timing_in_summary: + print(f" Timing-related summary keys: {timing_in_summary}") + else: + print(f" No timing keys in summary; first 15 summary keys: {sk[:15]}") + first_run_debugged = True + + for key in summary.keys(): + if "time" in key.lower() or "Time" in key or "AvgTime" in key: + timing_keys_seen.add(key) + + mean_key = "AvgTime/train_epoch_mean" + std_key = "AvgTime/train_epoch_std" + mean_time = metrics.get(mean_key) + std_time = metrics.get(std_key) + + if mean_time is None or std_time is None: + if verbose: + print( + f" (skip) ({model}, {dataset}): " + f"missing timing data (mean={mean_time}, std={std_time})" + ) + continue + + if not _finite(mean_time) or not _finite(std_time): + if verbose: + print( + f" (skip) ({model}, {dataset}): " + f"non-finite timing data (mean={mean_time}, std={std_time})" + ) + continue + + timing_data[(model, dataset)] = (float(mean_time), float(std_time)) + + if verbose: + print(f"\n{'-'*70}") + if timing_keys_seen: + print(f"Timing-related keys found across all runs: {sorted(timing_keys_seen)}") + else: + print("WARNING: No timing-related keys found in any run summary!") + print(f"Collected timing data for {len(timing_data)} (model, dataset) pairs") + print(f"{'-'*70}\n") + + return timing_data + + +# ----------------------------------------------------------------------------- +# LaTeX table building +# ----------------------------------------------------------------------------- + + +def _format_time(seconds: float, decimals: int = 2) -> str: + """Format time in seconds, use scientific notation if >= 1000.""" + if seconds >= 1000: + return f"{seconds:.{decimals}e}" + return f"{seconds:.{decimals}f}" + + +def _make_cell( + mean: float, + std: float, + is_best: bool, + is_stat_tied: bool, + decimals: int, +) -> str: + """ + Build a LaTeX table cell with color/bold based on statistical ranking. + + - bestgray + bold: is_best + - stdblue: is_stat_tied (but not best) + - plain: neither + """ + mean_str = _format_time(mean, decimals) + std_str = _format_time(std, decimals) + content = f"${mean_str} \\pm {std_str}$" + + if is_best: + return f"\\cellcolor{{bestgray}}\\textbf{{{content}}}" + elif is_stat_tied: + return f"\\cellcolor{{stdblue}}{content}" + else: + return content + + +def build_latex_table( + timing_data: dict[tuple[str, str], tuple[float, float]], + column_groups: list[tuple[str, list[tuple[str, str]]]], + graph_rows: list[tuple[str, str]], + simplicial_rows: list[tuple[str, str]], + cell_rows: list[tuple[str, str]], + decimals: int = 2, + label: str = "tbl:timing", +) -> str: + """ + Build the full LaTeX table. 
+ + Args: + timing_data: dict[(model, dataset)] = (mean_time, std_time) + column_groups: List of (group_title, [(dataset_path, header), ...]) + graph_rows: List of (model_path, latex_label) for graph models + simplicial_rows: List of (model_path, latex_label) for simplicial models + cell_rows: List of (model_path, latex_label) for cell models + decimals: Decimal places for numbers + label: LaTeX label + """ + lines = [] + + # Preamble + lines.append("\\begin{table}[t]") + lines.append("\\centering") + lines.append("\\begin{adjustbox}{max width=\\textwidth}") + + # Count total columns + total_cols = sum(len(cols) for _, cols in column_groups) + + # Table header + col_spec = "l" + "c" * total_cols + lines.append(f"\\begin{{tabular}}{{{col_spec}}}") + lines.append("\\toprule") + + # Build header rows + header_row1 = ["\\textbf{Model}"] + header_row2 = [""] + + for group_title, cols in column_groups: + if cols: + header_row1.append(f"\\multicolumn{{{len(cols)}}}{{c}}{{\\textbf{{{group_title}}}}}") + header_row2.extend([header for _, header in cols]) + + lines.append(" & ".join(header_row1) + " \\\\") + if len(column_groups) > 1: + # Add cmidrule for each group + col_idx = 2 # Start after Model column + for _, cols in column_groups: + if cols: + lines.append(f"\\cmidrule(lr){{{col_idx}-{col_idx + len(cols) - 1}}}") + col_idx += len(cols) + lines.append(" & ".join(header_row2) + " \\\\") + lines.append("\\midrule") + + # Collect all dataset paths in order + all_datasets: list[str] = [] + for _, cols in column_groups: + all_datasets.extend([path for path, _ in cols]) + + # Helper to emit a model block + def emit_model_block(block_title: str, rows: list[tuple[str, str]]) -> None: + if not rows: + return + + lines.append(f"\\multicolumn{{{total_cols + 1}}}{{l}}{{\\textbf{{{block_title}}}}} \\\\") + + for model_path, model_label in rows: + # For each dataset, find best and stat-tied + col_data: list[tuple[float, float] | None] = [] + for ds_path in all_datasets: + key = (model_path, ds_path) + if key in timing_data: + col_data.append(timing_data[key]) + else: + col_data.append(None) + + # Find best (minimum mean) for each column + row_cells = [model_label] + + for col_idx, data in enumerate(col_data): + ds_path = all_datasets[col_idx] + + if data is None: + row_cells.append("---") + continue + + mean, std = data + se = _sem(std, N_SEEDS_ASSUMED) + + # Find best in this column across all models + best_mean = float("inf") + best_se = 0.0 + + for all_model_path, _ in graph_rows + simplicial_rows + cell_rows: + key = (all_model_path, ds_path) + if key in timing_data: + m, s = timing_data[key] + if _finite(m) and m < best_mean: + best_mean = m + best_se = _sem(s, N_SEEDS_ASSUMED) + + is_best = abs(mean - best_mean) < 1e-12 + is_stat_tied = ( + not is_best + and _finite(best_mean) + and _not_sig_diff_from_best(mean, se, best_mean, best_se) + ) + + cell = _make_cell(mean, std, is_best, is_stat_tied, decimals) + row_cells.append(cell) + + lines.append(" & ".join(row_cells) + " \\\\") + + # Emit model blocks + emit_model_block("Graph", graph_rows) + lines.append("\\midrule") + emit_model_block("Simplicial", simplicial_rows) + lines.append("\\midrule") + emit_model_block("Cell", cell_rows) + + lines.append("\\bottomrule") + lines.append("\\end{tabular}") + lines.append("\\end{adjustbox}") + lines.append( + f"\\caption{{Training time per epoch (seconds, mean $\\pm$ std). 
" + f"\\cellcolor{{bestgray}}\\textbf{{Bold}}: best (lowest) time; " + f"\\cellcolor{{stdblue}}Blue: not significantly worse than best (95\\% CI).}}" + ) + lines.append(f"\\label{{{label}}}") + lines.append("\\end{table}") + + return "\n".join(lines) + "\n" + + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate LaTeX timing table from W&B best_runs_rerun project." + ) + parser.add_argument( + "--entity", + default=WANDB_ENTITY_DEFAULT, + help=f"W&B entity (default: {WANDB_ENTITY_DEFAULT})", + ) + parser.add_argument( + "--project", + default=WANDB_PROJECT_RERUNS, + help=f"W&B project name (default: {WANDB_PROJECT_RERUNS})", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + default=DEFAULT_TIMING_TABLE_TEX, + help=( + f"Output .tex for three-band table (transductive + inductive) " + f"(default: {DEFAULT_TIMING_TABLE_TEX})" + ), + ) + parser.add_argument( + "--output-without-transductive", + type=Path, + default=DEFAULT_TIMING_TABLE_TEX_NO_TRANS, + help=( + f"Second .tex: no transductive datasets " + f"(default: {DEFAULT_TIMING_TABLE_TEX_NO_TRANS})" + ), + ) + parser.add_argument( + "--stdout", + action="store_true", + help="Print LaTeX to stdout instead of writing files", + ) + parser.add_argument( + "--decimals", + type=int, + default=2, + help="Decimal places for timing values (default: 2)", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Less console output", + ) + + args = parser.parse_args() + + # Check for WANDB_API_KEY + if not os.environ.get("WANDB_API_KEY"): + print( + "Warning: WANDB_API_KEY not set. Run 'wandb login' or set the environment variable.", + file=sys.stderr, + ) + + # Collect timing data + timing_data = collect_timing_data( + args.entity, + args.project, + verbose=not args.quiet, + ) + + if not timing_data: + print("Error: No timing data found. 
Check entity/project names.", file=sys.stderr) + sys.exit(1) + + # Build dataset specs + base_specs = _specs_from_loader_paths() + + # Create both table variants + groups_three = partition_specs_three_way(base_specs) + groups_two = partition_specs_two_way_no_transductive(base_specs) + n_two_cols = sum(len(b) for _, b in groups_two) + + # Model rows (same as table_generator.py) + graph_rows = [ + ("graph/gcn", "GCN"), + ("graph/gat", "GAT"), + ("graph/gin", "GIN"), + ] + simplicial_rows = [ + ("simplicial/hopse_m", "\\textbf{HOPSE-M} (Our)"), + ("simplicial/hopse_g", "\\textbf{HOPSE-G} (Our)"), + ("simplicial/topotune", "TopoTune"), + ] + cell_rows = [ + ("cell/hopse_m", "\\textbf{HOPSE-M} (Our)"), + ("cell/hopse_g", "\\textbf{HOPSE-G} (Our)"), + ("cell/topotune", "TopoTune"), + ] + + # Build LaTeX tables + tex_three = build_latex_table( + timing_data, + column_groups=groups_three, + graph_rows=graph_rows, + simplicial_rows=simplicial_rows, + cell_rows=cell_rows, + decimals=args.decimals, + label="tbl:timing_all", + ) + + tex_two: str | None = None + if n_two_cols > 0: + tex_two = build_latex_table( + timing_data, + column_groups=groups_two, + graph_rows=graph_rows, + simplicial_rows=simplicial_rows, + cell_rows=cell_rows, + decimals=args.decimals, + label="tbl:timing_no_transductive", + ) + + # Output + if args.stdout: + sys.stdout.write(tex_three) + if tex_two is not None: + sys.stdout.write( + "\n% --- version without transductive graph (cocitation cora/citeseer/pubmed) ---\n\n" + ) + sys.stdout.write(tex_two) + else: + out1 = Path(args.output) + out1.parent.mkdir(parents=True, exist_ok=True) + out1.write_text(tex_three, encoding="utf-8") + print(f"Wrote {out1}") + + if tex_two is not None: + out2 = Path(args.output_without_transductive) + out2.parent.mkdir(parents=True, exist_ok=True) + out2.write_text(tex_two, encoding="utf-8") + print(f"Wrote {out2}") + else: + print( + "Skipped second table (--output-without-transductive): " + "no columns left after dropping transductive datasets." + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/hopse_plotting/utils.py b/scripts/hopse_plotting/utils.py new file mode 100644 index 000000000..ffc09e714 --- /dev/null +++ b/scripts/hopse_plotting/utils.py @@ -0,0 +1,1120 @@ +""" +Shared helpers for W&B TopoBench export CSVs: config constants, API helpers, +flattening, and aggregation of runs across data seeds. 
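+
+Typical pipeline (a sketch using the default paths defined below)::
+
+    df = load_wandb_export_csv(DEFAULT_WANDB_EXPORT_CSV)
+    agg = aggregate_wandb_export_by_seed(df)           # one row per hyperparameter setting
+    best = collapse_aggregated_wandb_by_best_val(agg)  # best-validation row per (model, dataset)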
+ +Default filesystem layout (under ``scripts/hopse_plotting/``): + +- ``csvs/`` — monolithic export, seed-aggregated CSV, collapsed CSV +- ``csvs/hopse_experiments_wandb_export_shards/`` — per-model or per-dataset shards from ``main_loader`` +- ``plots/leaderboard/`` — collapse / leaderboard figures (``main_plot``) +- ``plots/hyperparam/`` — ``model/dataset/`` trees from ``hyperparam_analysis`` +- ``tables/`` — LaTeX from ``table_generator`` +""" + +from __future__ import annotations + +import json +import random +import re +import time +from collections.abc import Mapping +from pathlib import Path +from typing import Any, Literal + +import pandas as pd + +# ----------------------------------------------------------------------------- +# Package paths (scripts/hopse_plotting) +# ----------------------------------------------------------------------------- + +_PLOT_PACKAGE_ROOT = Path(__file__).resolve().parent +CSV_DIR = _PLOT_PACKAGE_ROOT / "csvs" +WANDB_EXPORT_SHARDS_SUBDIR = "hopse_experiments_wandb_export_shards" +PLOTS_DIR = _PLOT_PACKAGE_ROOT / "plots" +TABLES_DIR = _PLOT_PACKAGE_ROOT / "tables" + +DEFAULT_WANDB_EXPORT_CSV = CSV_DIR / "hopse_experiments_wandb_export.csv" +DEFAULT_WANDB_EXPORT_SHARD_DIR = CSV_DIR / WANDB_EXPORT_SHARDS_SUBDIR +DEFAULT_AGGREGATED_EXPORT_CSV = CSV_DIR / "hopse_experiments_wandb_export_seed_agg.csv" +DEFAULT_COLLAPSED_EXPORT_CSV = CSV_DIR / "hopse_experiments_wandb_export_collapsed.csv" +DEFAULT_HYPERPARAM_PLOT_DIR = PLOTS_DIR / "hyperparam" +DEFAULT_LEADERBOARD_PLOT_DIR = PLOTS_DIR / "leaderboard" + +# ----------------------------------------------------------------------------- +# Column layout (must match main_loader export CSV columns) +# ----------------------------------------------------------------------------- + +CONFIG_PARAM_KEYS: list[str] = [ + "model", + "dataset", + "transforms", + "transforms.CombinedPSEs.encodings", + "transforms.CombinedFEs.encodings", + # SANN sweeps (``scripts/sann.sh``): k-hop transform + backbone/complex dims. + "transforms.sann_encoding.max_hop", + "transforms.sann_encoding.complex_dim", + "transforms.sann_encoding.max_rank", + # HOPSE_G / GPSE (``scripts/hopse_g.sh``): without ``pretrain_model``, molpcba vs zinc + # runs merge in seed aggregation (2 checkpoints × 5 seeds → ``n_seeds==10``). + "transforms.hopse_encoding.pretrain_model", + "transforms.hopse_encoding.neighborhoods", + "transforms.hopse_encoding.max_hop", + "transforms.hopse_encoding.max_rank", + "transforms.hopse_encoding.complex_dim", + "model.feature_encoder.selected_dimensions", + "model.backbone.complex_dim", + "model.preprocessing_params.neighborhoods", + "model.preprocessing_params.encodings", + "model.backbone.neighborhoods", + "model.backbone.num_layers", + "model.backbone.n_layers", + "model.backbone.GNN.num_layers", + "model.feature_encoder.out_channels", + "model.feature_encoder.proj_dropout", + "optimizer.parameters.lr", + "optimizer.parameters.weight_decay", + "dataset.dataloader_params.batch_size", + "dataset.split_params.data_seed", + "dataset.parameters.monitor_metric", +] + +META_COLUMNS: list[str] = [ + "wandb_entity", + "wandb_project", + "run_state", + "identifiers_run_id", + "identifiers_run_name", + "identifiers_run_url", + "identifiers_tags", +] + +SEED_COLUMN = "dataset.split_params.data_seed" +MONITOR_METRIC_COLUMN = "dataset.parameters.monitor_metric" +IDENTIFIER_COLUMN_PREFIX = "identifiers_" +SUMMARY_COLUMN_PREFIX = "summary_" + +# Hydra / PyTorch often expect ints for these; CSV aggregation yields floats (e.g. 
1.0) and breaks +# e.g. torch_geometric GAT: range(num_layers - 2). +HYDRA_WHOLE_NUMBER_OVERRIDE_KEYS: frozenset[str] = frozenset( + { + "model.backbone.num_layers", + "model.backbone.n_layers", + "model.backbone.GNN.num_layers", + "model.backbone.complex_dim", + "model.feature_encoder.out_channels", + "transforms.sann_encoding.max_hop", + "transforms.sann_encoding.complex_dim", + "transforms.sann_encoding.max_rank", + "transforms.hopse_encoding.max_hop", + "transforms.hopse_encoding.max_rank", + "transforms.hopse_encoding.complex_dim", + "dataset.dataloader_params.batch_size", + "dataset.split_params.data_seed", + } +) + +# W&B CSV cells often store OmegaConf/JSON lists as ``["a","b"]``; sweep scripts pass +# bracket lists without inner quotes (``[a,b]``). Normalize for CLI reproducibility. +HYDRA_JSON_LIST_TO_BRACKET_KEYS: frozenset[str] = frozenset( + { + "transforms.CombinedPSEs.encodings", + "transforms.CombinedFEs.encodings", + "model.preprocessing_params.neighborhoods", + "model.preprocessing_params.encodings", + "transforms.hopse_encoding.neighborhoods", + "model.backbone.neighborhoods", + "model.feature_encoder.selected_dimensions", + } +) + +# W&B / flattened configs record ``dataset.loader.parameters.data_name`` (Planetoid: Cora, +# not cocitation_cora). Hydra ``dataset=`` must match ``configs/dataset/`` without +# ``.yaml``. Add rows here when a loader identity does not equal that path. +# Omitted: ``graph/ZINC`` maps to both ``graph/ZINC`` and ``graph/ZINC_OGB`` (same data_name). +DATASET_LOADER_IDENTITY_TO_HYDRA: dict[str, str] = { + "graph/Cora": "graph/cocitation_cora", + "graph/citeseer": "graph/cocitation_citeseer", + "graph/PubMed": "graph/cocitation_pubmed", + "graph/manual": "graph/manual_dataset", + "hypergraph/20newsW100": "hypergraph/20newsgroup", + "simplicial/MANTRA_genus": "simplicial/mantra_genus", + "simplicial/MANTRA_name": "simplicial/mantra_name", + "simplicial/MANTRA_orientation": "simplicial/mantra_orientation", + "simplicial/MANTRA_betti_numbers": "simplicial/mantra_betti_numbers", +} + + +def hydra_dataset_key_from_loader_identity(identity: str) -> str: + """Map loader-style ``domain/data_name`` from exports to Hydra ``dataset=`` key.""" + ident = identity.replace("\r", "").strip() + if not ident: + return ident + return DATASET_LOADER_IDENTITY_TO_HYDRA.get(ident, ident) + + +# ----------------------------------------------------------------------------- +# W&B resilience helpers +# ----------------------------------------------------------------------------- + + +def wandb_transient_api_error(exc: BaseException) -> bool: + text = str(exc).lower() + markers = ( + "502", + "503", + "504", + "429", + "bad gateway", + "timed out", + "timeout", + "temporarily unavailable", + "connection reset", + ) + return any(m in text for m in markers) + + +def run_with_wandb_retry( + fn, + *, + max_retries: int = 6, + label: str = "W&B API", +): + last: BaseException | None = None + for attempt in range(max_retries): + try: + return fn() + except Exception as e: + last = e + if attempt == max_retries - 1 or not wandb_transient_api_error(e): + raise + delay = min(120.0, (2**attempt) * 10 + random.uniform(0, 3)) + print( + f"\n {label} transient error (attempt {attempt + 1}/{max_retries}): {e!s}\n" + f" Retrying in {delay:.0f}s ...\n" + ) + time.sleep(delay) + assert last is not None + raise last + + +# ----------------------------------------------------------------------------- +# Config flattening & value extraction (loader) +# 
----------------------------------------------------------------------------- + + +def _unwrap_wandb_value(v: Any) -> Any: + if isinstance(v, dict) and set(v.keys()) <= {"value", "desc", "params"}: + if "value" in v: + return _unwrap_wandb_value(v["value"]) + return v + + +def flatten_config(obj: Any, parent_key: str = "", sep: str = ".") -> dict[str, Any]: + out: dict[str, Any] = {} + if not isinstance(obj, Mapping): + return {parent_key: obj} if parent_key else {} + + for raw_k, raw_v in obj.items(): + k = str(raw_k) + if not parent_key and k.startswith("_"): + continue + v = _unwrap_wandb_value(raw_v) + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, Mapping): + out.update(flatten_config(v, new_key, sep=sep)) + else: + out[new_key] = v + return out + + +def _serialize_cell(x: Any) -> str: + if x is None: + return "" + if isinstance(x, bool): + return "true" if x else "false" + if isinstance(x, int | float) and not isinstance(x, bool): + return repr(x) if isinstance(x, float) else str(x) + if isinstance(x, str): + return x + try: + return json.dumps(x, sort_keys=True) + except TypeError: + return str(x) + + +def get_from_flat(flat: Mapping[str, Any], dotted: str) -> Any: + """Resolve a Hydra-style dotted key from W&B ``run.config`` (after ``flatten_config``). + + Lightning/W&B sometimes flatten nested hparams with ``/`` instead of ``.``; try both so + sweep axes like ``transforms.sann_encoding.max_hop`` are not dropped (else seed + aggregation can merge 3×5 runs into one bucket with ``n_seeds==15``). + """ + if dotted in flat: + return flat[dotted] + slashy = dotted.replace(".", "/") + if slashy in flat: + return flat[slashy] + return "" + + +def _resolved_model_path(flat: Mapping[str, Any]) -> str: + direct = get_from_flat(flat, "model") + if direct not in (None, ""): + if isinstance(direct, str): + return direct + return _serialize_cell(direct) + domain = get_from_flat(flat, "model.model_domain") + name = get_from_flat(flat, "model.model_name") + if domain and name: + return f"{domain}/{name}" + return "" + + +def _resolved_dataset_path(flat: Mapping[str, Any]) -> str: + direct = get_from_flat(flat, "dataset") + if direct not in (None, ""): + if isinstance(direct, str): + return hydra_dataset_key_from_loader_identity(direct.strip()) + return hydra_dataset_key_from_loader_identity(_serialize_cell(direct)) + domain = get_from_flat(flat, "dataset.loader.parameters.data_domain") + name = get_from_flat(flat, "dataset.loader.parameters.data_name") + if domain and name: + dd = domain if isinstance(domain, str) else _serialize_cell(domain) + dn = name if isinstance(name, str) else _serialize_cell(name) + return hydra_dataset_key_from_loader_identity(f"{dd}/{dn}") + return "" + + +def _resolved_transforms_preset(flat: Mapping[str, Any]) -> str: + direct = get_from_flat(flat, "transforms") + if direct not in (None, ""): + if isinstance(direct, str): + return direct + return _serialize_cell(direct) + if get_from_flat(flat, "transforms.CombinedPSEs.encodings"): + return "combined_pe" + if get_from_flat(flat, "transforms.CombinedFEs.encodings"): + return "combined_fe" + return "" + + +def extract_config_params(flat: Mapping[str, Any]) -> dict[str, str]: + row: dict[str, str] = {} + for key in CONFIG_PARAM_KEYS: + if key == "model": + row[key] = _resolved_model_path(flat) + elif key == "dataset": + row[key] = _resolved_dataset_path(flat) + elif key == "transforms": + row[key] = _resolved_transforms_preset(flat) + else: + row[key] = _serialize_cell(get_from_flat(flat, key)) 
+ return row + + +def summary_to_prefixed_row(summary: Mapping[str, Any]) -> dict[str, str]: + out: dict[str, str] = {} + for k, v in summary.items(): + col = f"{SUMMARY_COLUMN_PREFIX}{k}" + out[col] = _serialize_cell(v) + return out + + +def dataset_basename(dataset_path: str) -> str: + return dataset_path.rsplit("/", 1)[-1] + + +def expected_project_name(model: str, dataset_path: str) -> str: + return f"{model}_{dataset_basename(dataset_path)}" + + +def project_full_path(entity: str, project: str) -> str: + return f"{entity}/{project}" + + +def iter_runs(api, entity: str, project: str, *, state: str | None): + path = project_full_path(entity, project) + filters = {"state": state} if state else None + + def _list(): + return api.runs(path, filters=filters, per_page=500) + + return run_with_wandb_retry(_list, label=f"W&B list runs {path}") + + +def run_to_row( + *, + entity: str, + project: str, + run, +) -> dict[str, Any]: + flat = flatten_config(dict(run.config)) + meta = { + "wandb_entity": entity, + "wandb_project": project, + "run_state": run.state, + "identifiers_run_id": run.id, + "identifiers_run_name": run.name or "", + "identifiers_run_url": run.url, + "identifiers_tags": ",".join(run.tags or []), + } + params = extract_config_params(flat) + summ = summary_to_prefixed_row(dict(run.summary)) + + return {**meta, **params, **summ} + + +def collect_all_runs( + entity: str, + models: list[str], + datasets: list[str], + *, + run_state: str | None = "finished", + verbose: bool = True, +) -> list[dict[str, Any]]: + import wandb + + api = wandb.Api(timeout=120) + rows: list[dict[str, Any]] = [] + + for model in models: + for ds in datasets: + proj = expected_project_name(model, ds) + if verbose: + _filt = f"state={run_state}" if run_state else "all states" + print(f" (fetch) {entity}/{proj} ({_filt})", flush=True) + try: + runs_gen = iter_runs(api, entity, proj, state=run_state) + count = 0 + for run in runs_gen: + rows.append(run_to_row(entity=entity, project=proj, run=run)) + count += 1 + if verbose and count % 250 == 0: + print(f" … {count} run(s) so far", flush=True) + except Exception as e: + if verbose: + print(f" (skip) {e}", flush=True) + continue + if verbose: + print(f" -> {count} run(s)", flush=True) + return rows + + +def dataframe_from_rows(rows: list[dict[str, Any]]) -> pd.DataFrame: + if not rows: + return pd.DataFrame(columns=META_COLUMNS + CONFIG_PARAM_KEYS) + + df = pd.DataFrame(rows) + summary_cols = sorted(c for c in df.columns if c.startswith(SUMMARY_COLUMN_PREFIX)) + ordered = META_COLUMNS + CONFIG_PARAM_KEYS + summary_cols + rest = [c for c in df.columns if c not in ordered] + df = df[[c for c in ordered if c in df.columns] + rest] + df = df.fillna("") + return df + + +# ----------------------------------------------------------------------------- +# CSV I/O & seed aggregation +# ----------------------------------------------------------------------------- + + +def load_wandb_export_csv(path: str | Path) -> pd.DataFrame: + """Read an export CSV produced by ``main_loader``.""" + return pd.read_csv(path, low_memory=False) + + +def is_seed_aggregatable_summary_column(name: str) -> bool: + """ + Summary columns to keep when aggregating over seeds: metrics whose W&B key + path mentions train, val (including val_best_rerun), or test_best_rerun. 
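+
+    Illustrative column names (hypothetical)::
+
+        summary_train/loss            -> True
+        summary_val/accuracy          -> True
+        summary_test_best_rerun/mae   -> True
+        summary_test/accuracy         -> False  (plain test split is not aggregated)
+        identifiers_run_id            -> False  (no ``summary_`` prefix)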
+ """ + if not name.startswith(SUMMARY_COLUMN_PREFIX): + return False + tail = name[len(SUMMARY_COLUMN_PREFIX) :] + if "train/" in tail or "/train/" in tail: + return True + if "val/" in tail or "/val/" in tail: + return True + if "test_best_rerun/" in tail: + return True + return False + + +def list_seed_aggregatable_summary_columns(df: pd.DataFrame) -> list[str]: + cols = [c for c in df.columns if is_seed_aggregatable_summary_column(c)] + return sorted(cols) + + +def hyperparam_groupby_columns(df: pd.DataFrame) -> list[str]: + """All columns except identifiers, summary_*, and the data seed.""" + out: list[str] = [] + for c in df.columns: + if c.startswith(IDENTIFIER_COLUMN_PREFIX): + continue + if c.startswith(SUMMARY_COLUMN_PREFIX): + continue + if c == SEED_COLUMN: + continue + out.append(c) + return out + + +def aggregate_wandb_export_by_seed( + df: pd.DataFrame, + *, + seed_column: str = SEED_COLUMN, + summary_metric_columns: list[str] | None = None, +) -> pd.DataFrame: + """ + One row per hyperparameter setting (everything equal except identifiers, + summary columns, and ``seed_column``). + + For each group, ``n_seeds`` is the run count. Selected summary metrics + (train/..., val/..., test_best_rerun/...) get ``__mean`` and + ``__std`` (``std`` uses ``ddof=0`` so a single seed yields 0). + + Rows should include ``dataset.parameters.monitor_metric`` (from the loader) + for downstream collapse / reporting. + + All raw ``summary_*`` columns are dropped from the output; identifier and + seed columns are dropped. Non-aggregated context (e.g. wandb_entity) is kept. + """ + missing = [c for c in (seed_column,) if c not in df.columns] + if missing: + raise KeyError(f"CSV missing expected column(s): {missing}") + + df = df.copy() + if MONITOR_METRIC_COLUMN not in df.columns: + df[MONITOR_METRIC_COLUMN] = "" + + group_cols = hyperparam_groupby_columns(df) + if summary_metric_columns is None: + summary_metric_columns = list_seed_aggregatable_summary_columns(df) + + unknown = [c for c in summary_metric_columns if c not in df.columns] + if unknown: + raise KeyError(f"Unknown summary column(s): {unknown}") + + sub = df[group_cols].copy() + for c in summary_metric_columns: + sub[c] = pd.to_numeric(df[c], errors="coerce") + + g = sub.groupby(group_cols, dropna=False) + n_seeds = g.size().rename("n_seeds") + + mean_df = g[summary_metric_columns].mean() + mean_df.columns = [f"{c}__mean" for c in mean_df.columns] + + std_df = g[summary_metric_columns].std(ddof=0) + std_df.columns = [f"{c}__std" for c in std_df.columns] + + # ``n_seeds`` is a Series; concat with empty metric frames (no summary cols / no rows) + # must be 2D+2D or pandas 2.x raises "unaligned mixed dimensional NDFrame objects". + out = pd.concat([n_seeds.to_frame(), mean_df, std_df], axis=1).reset_index() + + # Stable metric column order: sort by base summary name, mean then std each + metric_sorted = sorted(summary_metric_columns) + tail = [] + for c in metric_sorted: + tail.append(f"{c}__mean") + tail.append(f"{c}__std") + + ordered = group_cols + ["n_seeds"] + tail + out = out[[c for c in ordered if c in out.columns]] + return out + + +def build_seed_bucket_report( + aggregated_df: pd.DataFrame, + *, + model_col: str = "model", + dataset_col: str = "dataset", + n_seeds_col: str = "n_seeds", +) -> pd.DataFrame: + """ + Count hyperparameter groups (rows of a seed-aggregated frame) by how many + raw runs were merged per group, broken down by (model, dataset). 
+ + Columns: ``model``, ``dataset``, ``n_seeds``, ``n_groups``, + ``pct_of_groups`` (percent of groups within that model+dataset, 0--100). + """ + if aggregated_df.empty: + return pd.DataFrame( + columns=[model_col, dataset_col, n_seeds_col, "n_groups", "pct_of_groups"] + ) + missing = [c for c in (model_col, dataset_col, n_seeds_col) if c not in aggregated_df.columns] + if missing: + raise KeyError(f"seed bucket report: missing column(s): {missing}") + + work = aggregated_df[[model_col, dataset_col, n_seeds_col]].copy() + work[n_seeds_col] = pd.to_numeric(work[n_seeds_col], errors="coerce").astype("Int64") + counts = ( + work.groupby([model_col, dataset_col, n_seeds_col], dropna=False) + .size() + .rename("n_groups") + .reset_index() + ) + totals = counts.groupby([model_col, dataset_col], dropna=False)["n_groups"].transform("sum") + counts["pct_of_groups"] = (counts["n_groups"] / totals * 100.0).round(2) + return counts.sort_values([model_col, dataset_col, n_seeds_col]).reset_index(drop=True) + + +def filter_aggregated_to_required_n_seeds( + aggregated_df: pd.DataFrame, + required_n_seeds: int, + *, + n_seeds_col: str = "n_seeds", +) -> pd.DataFrame: + """Keep only hyperparameter groups with exactly ``required_n_seeds`` runs.""" + if n_seeds_col not in aggregated_df.columns: + raise KeyError(f"filter aggregated: missing {n_seeds_col!r}") + ns = pd.to_numeric(aggregated_df[n_seeds_col], errors="coerce") + return aggregated_df.loc[ns == required_n_seeds].copy() + + +def aggregate_wandb_export_csv( + input_path: str | Path, + output_path: str | Path, + *, + summary_metric_columns: list[str] | None = None, + required_n_seeds: int | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Load export CSV, aggregate by seed, optionally keep only groups with an + exact run count, write ``output_path``. + + Returns ``(written_frame, seed_bucket_report)`` where the report is built + from the aggregate **before** filtering on ``required_n_seeds``. + """ + df = load_wandb_export_csv(input_path) + agg = aggregate_wandb_export_by_seed(df, summary_metric_columns=summary_metric_columns) + report = build_seed_bucket_report(agg) + if required_n_seeds is not None: + agg = filter_aggregated_to_required_n_seeds(agg, required_n_seeds) + agg = agg.fillna("") + out_p = Path(output_path) + out_p.parent.mkdir(parents=True, exist_ok=True) + agg.to_csv(out_p, index=False) + return agg, report + + +def _union_column_order(frames: list[pd.DataFrame]) -> list[str]: + """Stable union of column names in first-seen order (for concat alignment).""" + order: list[str] = [] + seen: set[str] = set() + for f in frames: + for c in f.columns: + if c not in seen: + seen.add(c) + order.append(c) + return order + + +def aggregate_many_wandb_export_csvs( + input_paths: list[str | Path], + output_path: str | Path, + *, + summary_metric_columns: list[str] | None = None, + required_n_seeds: int | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Load multiple per-run export CSVs (e.g. loader shards), aggregate each by + seed, concatenate rows, optionally filter to an exact ``n_seeds``, and + write one combined seed-aggregated CSV. + + Shards should partition runs (e.g. one file per model or per dataset) so + hyperparameter groups are not duplicated across files. + + The seed bucket report is computed on the **concatenated** unfiltered + aggregate (same keys as a single monolithic export). + + Returns ``(written_frame, seed_bucket_report)``. 
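+
+    Minimal usage sketch (shard directory and seed count are illustrative)::
+
+        shards = sorted(DEFAULT_WANDB_EXPORT_SHARD_DIR.glob("*.csv"))
+        agg, report = aggregate_many_wandb_export_csvs(
+            shards, DEFAULT_AGGREGATED_EXPORT_CSV, required_n_seeds=5
+        )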
+ """ + paths = [Path(p) for p in input_paths] + if not paths: + raise ValueError("aggregate_many_wandb_export_csvs: no input paths") + + frames: list[pd.DataFrame] = [] + for p in paths: + df = load_wandb_export_csv(p) + frames.append(aggregate_wandb_export_by_seed(df, summary_metric_columns=summary_metric_columns)) + + cols = _union_column_order(frames) + out = pd.concat(frames, ignore_index=True, sort=False) + out = out.reindex(columns=cols) + report = build_seed_bucket_report(out) + if required_n_seeds is not None: + out = filter_aggregated_to_required_n_seeds(out, required_n_seeds) + out = out.fillna("") + out_p = Path(output_path) + out_p.parent.mkdir(parents=True, exist_ok=True) + out.to_csv(out_p, index=False) + return out, report + + +# ----------------------------------------------------------------------------- +# Collapse seed-aggregated CSV: best hyperparams per (model, dataset, ...) +# ----------------------------------------------------------------------------- + +# Metric name (last path segment, lowercased) -> "max" or "min" for val split selection. +MONITOR_METRIC_OPTIMIZATION: dict[str, str] = { + "accuracy": "max", + "auroc": "max", + "roc_auc": "max", + "f1": "max", + "precision": "max", + "recall": "max", + "mae": "min", + "mse": "min", + "rmse": "min", + "loss": "min", +} + +DEFAULT_COLLAPSE_GROUP_COLS: list[str] = ["model", "dataset"] + + +def metric_name_tail(monitor_raw: str) -> str: + """Normalize ``dataset.parameters.monitor_metric`` to a W&B metric suffix (e.g. ``accuracy``).""" + m = str(monitor_raw).strip() + if not m or m.lower() in ("nan", "none"): + return "" + if "/" in m: + return m.rsplit("/", 1)[-1].strip().lower() + return m.lower() + + +def safe_metric_col_token(tail: str) -> str: + """Safe fragment for CSV column names such as ``train_accuracy_mean``.""" + t = re.sub(r"[^\w]+", "_", tail.strip().lower()).strip("_") + return t or "unknown" + + +def optimization_mode_for_metric_tail(tail: str) -> str: + mode = MONITOR_METRIC_OPTIMIZATION.get(tail.strip().lower(), "max") + return mode if mode in ("max", "min") else "max" + + +def _first_existing_column(candidates: list[str], available: set[str]) -> str | None: + for c in candidates: + if c in available: + return c + return None + + +def _paired_std_from_mean(mean_col: str | None, available: set[str]) -> str | None: + """``summary_*__mean`` -> matching ``summary_*__std`` if present in the frame.""" + if not mean_col or not str(mean_col).endswith("__mean"): + return None + s = str(mean_col) + std_col = s[: -len("__mean")] + "__std" + return std_col if std_col in available else None + + +def _val_mean_columns_for_tail(tail: str) -> list[str]: + return [ + f"{SUMMARY_COLUMN_PREFIX}val/{tail}__mean", + f"{SUMMARY_COLUMN_PREFIX}best_epoch/val/{tail}__mean", + f"{SUMMARY_COLUMN_PREFIX}val_best_rerun/{tail}__mean", + ] + + +def _train_mean_columns_for_tail(tail: str) -> list[str]: + return [ + f"{SUMMARY_COLUMN_PREFIX}train/{tail}__mean", + f"{SUMMARY_COLUMN_PREFIX}best_epoch/train/{tail}__mean", + ] + + +def _test_mean_columns_for_tail(tail: str) -> list[str]: + return [f"{SUMMARY_COLUMN_PREFIX}test_best_rerun/{tail}__mean"] + + +def iter_best_val_group_picks( + df: pd.DataFrame, + *, + group_cols: list[str] | None = None, + monitor_column: str = MONITOR_METRIC_COLUMN, +): + """ + For each ``group_cols`` group, pick the row index with best validation mean + (same rule as ``collapse_aggregated_wandb_by_best_val``). + + Yields ``(group_key_tuple, pick_idx, monitor_val, tail)``. 
+ """ + if group_cols is None: + group_cols = list(DEFAULT_COLLAPSE_GROUP_COLS) + + missing_g = [c for c in group_cols if c not in df.columns] + if missing_g: + raise KeyError(f"collapse: missing group column(s): {missing_g}") + if monitor_column not in df.columns: + raise KeyError(f"collapse: missing {monitor_column!r} (re-run loader / aggregate).") + + work = df + colset = set(work.columns) + + for _gk, sub in work.groupby(group_cols, dropna=False): + keys = _gk if isinstance(_gk, tuple) else (_gk,) + if len(keys) != len(group_cols): + raise RuntimeError("groupby key length mismatch") + + mon_series = ( + sub[monitor_column] + .dropna() + .astype(str) + .str.strip() + .replace({"nan": "", "NaN": ""}) + ) + mon_series = mon_series[mon_series != ""] + monitor_val = mon_series.iloc[0] if len(mon_series) else "" + + tail = metric_name_tail(monitor_val) + + pick_idx = sub.index[0] + val_src = _first_existing_column(_val_mean_columns_for_tail(tail), colset) if tail else None + if val_src is not None: + scores = pd.to_numeric(sub[val_src], errors="coerce") + if scores.notna().any(): + mode = optimization_mode_for_metric_tail(tail) + pick_idx = scores.idxmax() if mode == "max" else scores.idxmin() + + yield keys, pick_idx, monitor_val, tail + + +def aggregated_rows_best_validation_per_group( + df: pd.DataFrame, + *, + group_cols: list[str] | None = None, + monitor_column: str = MONITOR_METRIC_COLUMN, +) -> pd.DataFrame: + """ + Full **seed-aggregated** rows for the best validation setting in each group + (same picks as collapse / leaderboard), including all config columns. + """ + work = df.copy() + picked: list[pd.Series] = [] + for _keys, pick_idx, _monitor_val, _tail in iter_best_val_group_picks( + work, group_cols=group_cols, monitor_column=monitor_column + ): + picked.append(work.loc[pick_idx]) + if not picked: + return pd.DataFrame() + return pd.DataFrame(picked).reset_index(drop=True) + + +def _serialize_hydra_cli_value(val: Any) -> str | None: + if val is None: + return None + if isinstance(val, float) and pd.isna(val): + return None + s = str(val).replace("\r", "").strip() + if s == "" or s.lower() in {"nan", "none", ""}: + return None + return s + + +def normalize_json_list_string_for_hydra_cli(s: str) -> str: + """ + If ``s`` is a JSON array, return a Hydra-style bracket list ``[a,b,c]`` (no spaces + after commas): string elements as in ``gat.sh`` / ``hopse_m.sh``; integer elements + as in ``sann.sh`` (``model.feature_encoder.selected_dimensions``). + + Otherwise return ``s`` unchanged (already ``[a,b]``, not JSON, or invalid). 
+    """
+    t = s.replace("\r", "").strip()
+    if len(t) < 2 or t[0] != "[":
+        return s
+    try:
+        parsed = json.loads(t)
+    except json.JSONDecodeError:
+        return s
+    if not isinstance(parsed, list) or not parsed:
+        return s
+    if all(isinstance(x, str) and x for x in parsed):
+        return "[" + ",".join(parsed) + "]"
+    if all(isinstance(x, bool) for x in parsed):
+        return s
+    if all(isinstance(x, int) for x in parsed):
+        return "[" + ",".join(str(x) for x in parsed) + "]"
+    if all(isinstance(x, (int, float)) for x in parsed):
+        try:
+            ints: list[int] = []
+            for x in parsed:
+                xf = float(x)
+                if not xf.is_integer():
+                    return s
+                ints.append(int(xf))
+            return "[" + ",".join(str(x) for x in ints) + "]"
+        except (TypeError, ValueError):
+            return s
+    return s
+
+
+def _coerce_whole_number_override(key: str, s: str) -> str:
+    """Emit 1 instead of 1.0 for keys that must be integers in YAML / native code."""
+    if key not in HYDRA_WHOLE_NUMBER_OVERRIDE_KEYS or not s:
+        return s
+    try:
+        x = float(s)
+    except ValueError:
+        return s
+    if x.is_integer():
+        return str(int(x))
+    return s
+
+
+def hydra_overrides_from_aggregated_row(
+    row: Any,
+    *,
+    config_keys: list[str] | None = None,
+    skip_keys: set[str] | None = None,
+) -> list[str]:
+    """
+    Build ``key=value`` strings for ``python -m topobench`` from a loader-style
+    config column set (``CONFIG_PARAM_KEYS``). Skips empty / NaN cells.
+    """
+    if config_keys is None:
+        config_keys = list(CONFIG_PARAM_KEYS)
+    skip = skip_keys or set()
+    out: list[str] = []
+    for key in config_keys:
+        if key in skip:
+            continue
+        if key not in row:
+            continue
+        s = _serialize_hydra_cli_value(row.get(key))
+        if s is None:
+            continue
+        if key == "dataset":
+            s = hydra_dataset_key_from_loader_identity(s)
+        if key in HYDRA_JSON_LIST_TO_BRACKET_KEYS:
+            s = normalize_json_list_string_for_hydra_cli(s)
+        s = _coerce_whole_number_override(key, s)
+        out.append(f"{key}={s}")
+    return out
+
+
+def collapse_aggregated_wandb_by_best_val(
+    df: pd.DataFrame,
+    *,
+    group_cols: list[str] | None = None,
+    monitor_column: str = MONITOR_METRIC_COLUMN,
+) -> pd.DataFrame:
+    """
+    From a **seed-aggregated** export (``...__mean`` / ``...__std`` columns), keep one
+    row per ``group_cols`` by picking the hyperparameter row with the best **validation**
+    mean for the dataset's monitored metric.
+
+    Output columns: ``group_cols``, ``monitor_column``, then a sparse wide block
+    ``train_<metric>_mean``, ``train_<metric>_std``, ``val_<metric>_mean``,
+    ``val_<metric>_std``, ``test_<metric>_mean``, ``test_<metric>_std`` for every
+    metric tail that appears anywhere in ``monitor_column``; only the block matching
+    that row's monitor is filled, others are empty. Std values come from the paired
+    ``summary_*__std`` columns of the winning row.
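+
+    For example, a row whose monitor is ``val/accuracy`` fills only the ``accuracy``
+    block (any ``mae``/``mse`` blocks coming from other groups stay empty for that row)::
+
+        train_accuracy_mean, train_accuracy_std,
+        val_accuracy_mean,   val_accuracy_std,
+        test_accuracy_mean,  test_accuracy_std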
+ """ + if group_cols is None: + group_cols = list(DEFAULT_COLLAPSE_GROUP_COLS) + + missing_g = [c for c in group_cols if c not in df.columns] + if missing_g: + raise KeyError(f"collapse: missing group column(s): {missing_g}") + if monitor_column not in df.columns: + raise KeyError(f"collapse: missing {monitor_column!r} (re-run loader / aggregate).") + + work = df.copy() + colset = set(work.columns) + + tails_seen: set[str] = set() + for v in work[monitor_column].fillna("").astype(str): + t = metric_name_tail(v) + if t: + tails_seen.add(t) + + tokens_sorted = sorted({safe_metric_col_token(t) for t in tails_seen}) + metric_block_cols: list[str] = [] + for tok in tokens_sorted: + metric_block_cols.extend( + [ + f"train_{tok}_mean", + f"train_{tok}_std", + f"val_{tok}_mean", + f"val_{tok}_std", + f"test_{tok}_mean", + f"test_{tok}_std", + ] + ) + + out_rows: list[dict[str, Any]] = [] + + for keys, pick_idx, monitor_val, tail in iter_best_val_group_picks( + work, group_cols=group_cols, monitor_column=monitor_column + ): + base_row = dict(zip(group_cols, keys, strict=True)) + tok = safe_metric_col_token(tail) if tail else "unknown" + + base_row[monitor_column] = monitor_val + + for c in metric_block_cols: + base_row[c] = "" + + winner = work.loc[pick_idx] + + if tail: + train_src = _first_existing_column(_train_mean_columns_for_tail(tail), colset) + val_src_w = _first_existing_column(_val_mean_columns_for_tail(tail), colset) + test_src = _first_existing_column(_test_mean_columns_for_tail(tail), colset) + + if train_src: + base_row[f"train_{tok}_mean"] = winner.get(train_src, "") + tr_std = _paired_std_from_mean(train_src, colset) + if tr_std: + base_row[f"train_{tok}_std"] = winner.get(tr_std, "") + if val_src_w: + base_row[f"val_{tok}_mean"] = winner.get(val_src_w, "") + va_std = _paired_std_from_mean(val_src_w, colset) + if va_std: + base_row[f"val_{tok}_std"] = winner.get(va_std, "") + if test_src: + base_row[f"test_{tok}_mean"] = winner.get(test_src, "") + te_std = _paired_std_from_mean(test_src, colset) + if te_std: + base_row[f"test_{tok}_std"] = winner.get(te_std, "") + + out_rows.append(base_row) + + out = pd.DataFrame(out_rows) + ordered = list(group_cols) + [monitor_column] + metric_block_cols + out = out[[c for c in ordered if c in out.columns]] + return out.fillna("") + + +def collapse_aggregated_wandb_csv( + input_path: str | Path, + output_path: str | Path, + *, + group_cols: list[str] | None = None, + monitor_column: str = MONITOR_METRIC_COLUMN, +) -> pd.DataFrame: + """Load seed-aggregated CSV, collapse to best val per group, write CSV.""" + df = load_wandb_export_csv(input_path) + collapsed = collapse_aggregated_wandb_by_best_val( + df, group_cols=group_cols, monitor_column=monitor_column + ) + out_p = Path(output_path) + out_p.parent.mkdir(parents=True, exist_ok=True) + collapsed.to_csv(out_p, index=False) + return collapsed + + +# ----------------------------------------------------------------------------- +# Hyperparameter sensitivity (seed-aggregated CSV, group by model) +# ----------------------------------------------------------------------------- + + +def hyperparam_axis_columns(df: pd.DataFrame) -> list[str]: + """ + Config columns to treat as hyperparameters for sensitivity plots. + + Uses ``CONFIG_PARAM_KEYS`` present in ``df``, excluding ``model`` (group key) + and the data-seed column (not present after seed aggregation). 
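+
+    Hypothetical result (illustrative only; the exact list depends on which
+    ``CONFIG_PARAM_KEYS`` columns are present in ``df``)::
+
+        ["dataset", "optimizer.parameters.lr", "dataset.dataloader_params.batch_size"]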
+ """ + out: list[str] = [] + for c in CONFIG_PARAM_KEYS: + if c == "model": + continue + if c == SEED_COLUMN: + continue + if c in df.columns: + out.append(c) + return out + + +def _nonempty_str_nunique(series: pd.Series) -> int: + t = series.astype(str).str.strip() + t = t.mask(t.isin({"", "nan", "None", "NaN", ""})) + return int(t.nunique(dropna=True)) + + +def varied_hyperparam_columns( + df: pd.DataFrame, + *, + candidate_cols: list[str] | None = None, +) -> list[str]: + """Columns among ``candidate_cols`` with more than one distinct non-empty value.""" + if candidate_cols is None: + candidate_cols = hyperparam_axis_columns(df) + varied: list[str] = [] + for c in candidate_cols: + if c not in df.columns: + continue + if _nonempty_str_nunique(df[c]) > 1: + varied.append(c) + return varied + + +def val_metric_mean_per_row( + df: pd.DataFrame, + *, + monitor_column: str = MONITOR_METRIC_COLUMN, +) -> pd.Series: + """ + For each row, validation **mean** (seed-aggregated) for that row's + ``dataset.parameters.monitor_metric``, using the same column resolution + order as ``collapse_aggregated_wandb_by_best_val``. + """ + colset = set(df.columns) + if monitor_column not in df.columns: + return pd.Series(float("nan"), index=df.index, dtype="float64") + + def _one(row: pd.Series) -> float: + tail = metric_name_tail(str(row.get(monitor_column, ""))) + if not tail: + return float("nan") + src = _first_existing_column(_val_mean_columns_for_tail(tail), colset) + if not src: + return float("nan") + v = pd.to_numeric(row.get(src, float("nan")), errors="coerce") + return float(v) if pd.notna(v) else float("nan") + + return df.apply(_one, axis=1) + + +def infer_hyperparam_plot_kind( + series: pd.Series, + *, + min_scatter_unique: int = 8, + min_numeric_frac: float = 0.78, + max_bar_categories: int = 48, +) -> tuple[Literal["scatter", "bar", "skip"], pd.Series]: + """ + Decide scatter (continuous) vs bar (categorical / low cardinality). + + Returns ``(kind, x_values)`` where for ``scatter``, ``x_values`` is numeric; + for ``bar``, ``x_values`` is string category labels; for ``skip``, too many + categories for a readable bar chart. + """ + s = series.copy() + num = pd.to_numeric(s, errors="coerce") + n = len(s) + if n == 0: + return "skip", s + frac_num = float(num.notna().sum()) / float(n) + n_u_num = int(num.dropna().nunique()) + + if frac_num >= min_numeric_frac and n_u_num >= min_scatter_unique: + return "scatter", num + + lab = s.astype(str).str.strip() + lab = lab.replace({"": "«empty»", "nan": "«empty»", "None": "«empty»", "NaN": "«empty»"}) + n_u_lab = int(lab.nunique(dropna=False)) + if n_u_lab > max_bar_categories: + return "skip", lab + return "bar", lab + + +def safe_filename_token(name: str, *, max_len: int = 80) -> str: + """Filesystem-safe fragment from a column name or model id.""" + t = re.sub(r"[^\w.\-]+", "_", str(name).strip()).strip("_") + if not t: + t = "unknown" + return t[:max_len] diff --git a/scripts/sann.sh b/scripts/sann.sh new file mode 100755 index 000000000..cc5627657 --- /dev/null +++ b/scripts/sann.sh @@ -0,0 +1,509 @@ +#!/bin/bash +# ============================================================================== +# SCRIPT: sann.sh +# DESCRIPTION: +# Runs a scalable hyperparameter sweep for SANN models across both +# simplicial and cellular domains. +# - ARCHITECTURE: Uses a "Cartesian Product" generation strategy. +# - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU. +# - ORDERING: Prioritizes running all seeds for a config before moving on. 
+# - FILTERING: Skips invalid model+dataset combos (cell + simplicial data). +# ============================================================================== +# DO NOT MISS THIS + +export SELECTED_GPUS="0,1,2,3,4,5,6,7" +wandb_entity="gbg141-hopse" +RESUME=true + # Set to true to skip already-completed runs (reads SUCCESSFUL_RUNS.log) + +# ============================================================================== +# SECTION 1: LOGGING & ENVIRONMENT SETUP +# ============================================================================== + +# 1.1 Define Project Identifiers +script_name="$(basename "${BASH_SOURCE[0]}" .sh)" +project_name="${script_name}" +log_group="sann_sweep" +LOG_DIR="./logs/${log_group}" + +echo "==========================================================" +echo " Preparing log directory: $LOG_DIR" +echo "==========================================================" + +# 1.2 Log directory management +if [[ "$RESUME" == "true" ]]; then + echo "⏩ RESUME MODE: Keeping existing logs." + mkdir -p "$LOG_DIR" +else + if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi + mkdir -p "$LOG_DIR" +fi + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Output format: "JOBS_PER_GPU gpu_id_0 gpu_id_1 ..." +# Thresholds: >= 80 GB -> 4 jobs, <= 30 GB -> 2 jobs, between -> 3 jobs. +_gpu_info=$(python3 -c " +import subprocess +import os + +# 1. Read the allowed GPUs from the environment variable +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + idx = idx.strip() + + # 2. 
Skip this GPU if it's not in our selected list + if allowed_gpus and idx not in allowed_gpus: + continue + + indices.append(idx) + mem_mb.append(int(mem.strip())) + + # Safety check in case the selected GPUs don't exist + if not indices: + print('0') + exit(0) + + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 5 + elif min_mem_gb <= 10: + jobs = 1 + elif min_mem_gb <= 30: + jobs = 2 + else: + jobs = 3 + + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Models (both domains) --- +# Use "alias::hydra_value" to disambiguate run names (both share basename "sann"). +models=( + "sim_sann::simplicial/sann" +) + +# --- Datasets --- +datasets=( + "graph/MUTAG" + "graph/PROTEINS" + "graph/NCI1" + "graph/NCI109" + "graph/BBB_Martins" + "graph/Caco2_Wang" + "graph/Clearance_Hepatocyte_AZ" + "graph/CYP3A4_Veith" + "simplicial/mantra_name" + "simplicial/mantra_orientation" + "simplicial/mantra_betti_numbers" + "graph/cocitation_cora" + "graph/cocitation_citeseer" + "graph/cocitation_pubmed" + "graph/ZINC" +) + +# --- Max Hops (SANN-specific, controls k-hop feature precomputation) --- +max_hops=(1 2 3) + +# --- Hyperparameters --- +num_layers=(1 2 4) +hidden_channels=(128 256) +proj_dropouts=(0.25 0.5) +lrs=(0.01 0.001) +weight_decays=(0.0001) +batch_sizes=(128 256) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=10" + "delete_checkpoint_after_test=True" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING (CRITICAL ORDERING) +# Format: "ShortTag | HydraKey | ${Array[*]}" +# +# Values support an optional "alias::hydra_value" syntax for readable names. +# The generator also filters out invalid model+dataset combos. 
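+#
+# Example (illustrative): the entry "lr|optimizer.parameters.lr|${lrs[*]}" sweeps the
+# learning rate, emits the Hydra override optimizer.parameters.lr=<value> for each run,
+# and contributes a "lr0.01"-style token to the generated run name.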
+# +# Key differences from hopse_m: +# - No neighborhoods or encodings (SANN uses k-hop feature precomputation) +# - max_hop: transforms.sann_encoding.max_hop (SANN-specific) +# - Backbone layers: model.backbone.n_layers (same as hopse_m) +# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + "mh|transforms.sann_encoding.max_hop|${max_hops[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.n_layers|${num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Smart Transductive Filtering) +# ============================================================================== + +# Define where your dataset YAMLs live so the generator can inspect them. +# UPDATE THIS PATH IF YOUR CONFIGS ARE STORED ELSEWHERE. +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size in the sweep so we don't duplicate transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + model_val = vals_dict.get('model', '') + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Rule A: Skip cell model + simplicial dataset --- + if model_val.startswith('cell/') and dataset_val.startswith('simplicial/'): + skipped += 1 + continue + + # --- Rule B: Transductive Batch Size Handler --- + is_transductive = False + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + # Construct path to yaml (e.g., ./configs/dataset/graph/cocitation_cora.yaml) + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + # Fast text check avoids needing pip install pyyaml + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'⚠️ WARNING: Could not find config at {yaml_path}', file=sys.stderr) + + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + # If this isn't the first batch size in the sweep list, skip it + # to avoid running the exact same bs=1 experiment multiple times. 
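+        # (e.g. with batch_sizes=(128 256), only the bs=128 combination survives
+        # here and is then forced to batch_size=1 below)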
+ if current_bs != first_bs: + skipped += 1 + continue + + # Mutate the current combination to force batch_size to 1 + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + # Force the value to 1. If an alias was used, keep it clean. + new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# --- Helper: Resolve complex_dim per dataset --- +# The PrecomputeKHopFeatures transform needs complex_dim to match the number +# of incidence matrices. For graph datasets, this is determined by the clique +# lifting's complex_dim (= max_dim_if_lifted, default 2). For simplicial +# datasets, we read manifold_dim from the YAML (default 2). +complex_dim_cache = {} +def get_complex_dim(dataset_val): + if dataset_val in complex_dim_cache: + return complex_dim_cache[dataset_val] + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + dim = 2 # safe default (clique lifting default) + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + import re + content = f.read() + # Check for max_dim_if_lifted (graph datasets) + m = re.search(r'max_dim_if_lifted:\s*(\d+)', content) + if m: + dim = int(m.group(1)) + elif dataset_val.startswith('simplicial/'): + # For simplicial datasets, use manifold_dim (default 2) + m2 = re.search(r'manifold_dim:\s*(\d+)', content) + if m2: + dim = int(m2.group(1)) + complex_dim_cache[dataset_val] = dim + return dim + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + dataset_val = '' + for (tag, key, val) in combo: + if '::' in val: + alias, hydra_val_str = val.split('::', 1) + clean_val = alias + actual_val = hydra_val_str + else: + clean_val = os.path.basename(val) + actual_val = val + + if key == 'dataset': + dataset_val = actual_val + + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + cmd_args.append(f'{key}={actual_val}') + + # Append complex_dim, max_rank and selected_dimensions based on dataset. + # complex_dim=K means the transform reads incidence_1..incidence_K, producing + # features for ranks 0..K-1 (K ranks total). The model's selected_dimensions + # and backbone/readout complex_dim (max_rank) must all agree on K. + cdim = get_complex_dim(dataset_val) + cmd_args.append(f'transforms.sann_encoding.complex_dim={cdim}') + cmd_args.append(f'transforms.sann_encoding.max_rank={cdim - 1}') + sel_dims = ','.join(str(x) for x in range(cdim)) + cmd_args.append(f'model.feature_encoder.selected_dimensions=[{sel_dims}]') + cmd_args.append(f'++model.backbone.complex_dim={cdim}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + +# ============================================================================== +# SECTION 5.5: RESUME — LOAD COMPLETED RUNS +# ============================================================================== + +declare -A _completed_runs +if [[ "$RESUME" == "true" ]]; then + # run_and_log nests: $LOG_DIR/$log_group/SUCCESSFUL_RUNS.log + _success_log="$LOG_DIR/$log_group/SUCCESSFUL_RUNS.log" + if [[ -f "$_success_log" ]]; then + while IFS= read -r _line; do + # Format: "DATE: [SUCCESS] run_name" + _rname="${_line##*\[SUCCESS\] }" + _completed_runs["$_rname"]=1 + done < "$_success_log" + echo "✔ Loaded ${#_completed_runs[@]} completed runs to skip." 
+ else + echo "⚠️ No SUCCESSFUL_RUNS.log found at $_success_log — nothing to skip." + fi +fi + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." +echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +skipped_completed=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.2.1 Skip if already completed (RESUME mode) + if [[ "$RESUME" == "true" && -n "${_completed_runs[$run_name]+x}" ]]; then + ((skipped_completed++)) + continue + fi + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + + # --- Extract dataset name for dynamic W&B project --- + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + "+logger.wandb.name=${run_name}" + ) + + # 6.6 Execute + run_and_log "${cmd[*]}" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total, $skipped_completed skipped as already completed)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait +echo "✔ All runs complete." 
diff --git a/scripts/sccnn.sh b/scripts/sccnn.sh new file mode 100644 index 000000000..13324b669 --- /dev/null +++ b/scripts/sccnn.sh @@ -0,0 +1,428 @@ +#!/bin/bash +# ============================================================================== +# SCRIPT: sccnn.sh +# DESCRIPTION: +# Runs a scalable hyperparameter sweep for SCCNN (Simplicial Complex Convolutional +# Neural Network) across both graph datasets (lifted to simplicial complexes) +# and native simplicial datasets (mantra). +# - ARCHITECTURE: Uses a "Cartesian Product" generation strategy. +# - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU. +# - ORDERING: Prioritizes running all seeds for a config before moving on. +# - FILTERING: Transductive datasets forced to batch_size=1. +# ============================================================================== + +export SELECTED_GPUS="0,1,2,3,4,5,6,7" +wandb_entity="gbg141-hopse" +RESUME=true # Set to true to skip already-completed runs (reads SUCCESSFUL_RUNS.log) + +# ============================================================================== +# SECTION 1: LOGGING & ENVIRONMENT SETUP +# ============================================================================== + +# 1.1 Define Project Identifiers +script_name="$(basename "${BASH_SOURCE[0]}" .sh)" +project_name="${script_name}" +log_group="sccnn_sweep" +LOG_DIR="./logs/${log_group}" + +echo "==========================================================" +echo " Preparing log directory: $LOG_DIR" +echo "==========================================================" + +# 1.2 Log directory management +if [[ "$RESUME" == "true" ]]; then + echo "⏩ RESUME MODE: Keeping existing logs." + mkdir -p "$LOG_DIR" +else + if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi + mkdir -p "$LOG_DIR" +fi + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 + +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Thresholds: >= 80 GB -> 5 jobs, <= 10 GB -> 1 job, <= 30 GB -> 2 jobs, else 3. 
+_gpu_info=$(python3 -c " +import subprocess +import os + +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + idx = idx.strip() + if allowed_gpus and idx not in allowed_gpus: + continue + indices.append(idx) + mem_mb.append(int(mem.strip())) + if not indices: + print('0') + exit(0) + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 5 + elif min_mem_gb <= 10: + jobs = 1 + elif min_mem_gb <= 30: + jobs = 2 + else: + jobs = 3 + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Model --- +# sccnn_custom uses AllCellFeatureEncoder with selected_dimensions=[0,1,2]. +# Graph datasets are lifted to simplicial complexes via clique lifting (dim=2), +# producing ranks 0, 1, 2. Mantra datasets are native simplicial complexes. 
+models=( + "simplicial/sccnn_custom" +) + +# --- Datasets --- +datasets=( + "graph/MUTAG" + "graph/PROTEINS" + "graph/NCI1" + "graph/NCI109" + "graph/BBB_Martins" + "graph/Caco2_Wang" + "graph/Clearance_Hepatocyte_AZ" + "graph/CYP3A4_Veith" + "simplicial/mantra_name" + "simplicial/mantra_orientation" + "simplicial/mantra_betti_numbers" + "graph/cocitation_cora" + "graph/cocitation_citeseer" + "graph/cocitation_pubmed" + "graph/ZINC" +) + +# --- Hyperparameters --- +num_layers=(1 2 4) +hidden_channels=(128 256) +proj_dropouts=(0.0 0.25) +lrs=(0.01 0.001) +weight_decays=(0.0001) +batch_sizes=(128 256) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=10" + "delete_checkpoint_after_test=True" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING +# Format: "ShortTag | HydraKey | ${Array[*]}" +# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.n_layers|${num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Transductive Filtering) +# ============================================================================== + +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size to avoid duplicating transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. 
Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Transductive Batch Size Handler --- + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + is_transductive = False + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'WARNING: Could not find config at {yaml_path}', file=sys.stderr) + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + if current_bs != first_bs: + skipped += 1 + continue + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + for (tag, key, val) in combo: + if '::' in val: + alias, actual_val = val.split('::', 1) + clean_val = alias + else: + clean_val = os.path.basename(val) + actual_val = val + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + cmd_args.append(f'{key}={actual_val}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + + +# ============================================================================== +# SECTION 5.5: RESUME — LOAD COMPLETED RUNS +# ============================================================================== + +declare -A _completed_runs +if [[ "$RESUME" == "true" ]]; then + _success_log="$LOG_DIR/$log_group/SUCCESSFUL_RUNS.log" + if [[ -f "$_success_log" ]]; then + while IFS= read -r _line; do + _rname="${_line##*\[SUCCESS\] }" + _completed_runs["$_rname"]=1 + done < "$_success_log" + echo "✔ Loaded ${#_completed_runs[@]} completed runs to skip." + else + echo "⚠️ No SUCCESSFUL_RUNS.log found at $_success_log — nothing to skip." + fi +fi + + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." 
+echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +skipped_completed=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.2.1 Skip if already completed (RESUME mode) + if [[ "$RESUME" == "true" && -n "${_completed_runs[$run_name]+x}" ]]; then + ((skipped_completed++)) + continue + fi + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + + # Extract dataset name for dynamic W&B project + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + "+logger.wandb.name=${run_name}" + ) + + # 6.6 Execute + run_and_log "${cmd[*]}" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total, $skipped_completed skipped as already completed)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait diff --git a/scripts/topotune.sh b/scripts/topotune.sh new file mode 100755 index 000000000..ed69e6499 --- /dev/null +++ b/scripts/topotune.sh @@ -0,0 +1,470 @@ +#!/bin/bash +# ============================================================================== +# SCRIPT: topotune.sh +# DESCRIPTION: +# Runs a scalable hyperparameter sweep for TopoTune models across both +# simplicial and cellular domains. +# - ARCHITECTURE: Uses a "Cartesian Product" generation strategy. +# - CONCURRENCY: Uses "Virtual Slots" to run N jobs per GPU. +# - ORDERING: Prioritizes running all seeds for a config before moving on. +# - FILTERING: Skips invalid model+dataset combos (cell + simplicial data). 
+# ============================================================================== +# DO NOT MISS THIS + +export SELECTED_GPUS="0,1,2,3,4,5,6,7" +wandb_entity="gbg141-hopse" +RESUME=false # Set to true to skip already-completed runs (reads SUCCESSFUL_RUNS.log) + +# ============================================================================== +# SECTION 1: LOGGING & ENVIRONMENT SETUP +# ============================================================================== + +# 1.1 Define Project Identifiers +script_name="$(basename "${BASH_SOURCE[0]}" .sh)" +project_name="${script_name}" +log_group="topotune_sweep" +LOG_DIR="./logs/${log_group}" + +echo "==========================================================" +echo " Preparing log directory: $LOG_DIR" +echo "==========================================================" + +# 1.2 Log directory management +if [[ "$RESUME" == "true" ]]; then + echo "⏩ RESUME MODE: Keeping existing logs." + mkdir -p "$LOG_DIR" +else + if [ -d "$LOG_DIR" ]; then rm -r "$LOG_DIR"; fi + mkdir -p "$LOG_DIR" +fi + +# 1.3 Robust Dependency Loading +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export HYDRA_FULL_ERROR=1 + +find_logging_script() { + local dir="$1" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/base/logging.sh" ]]; then echo "$dir/base/logging.sh"; return 0; fi + if [[ -f "$dir/scripts/base/logging.sh" ]]; then echo "$dir/scripts/base/logging.sh"; return 0; fi + dir="$(dirname "$dir")" + done + return 1 +} + +LOGGING_PATH=$(find_logging_script "$SCRIPT_DIR") +if [[ -n "$LOGGING_PATH" ]]; then + echo "✔ Found logging utils at: $LOGGING_PATH" + source "$LOGGING_PATH" +else + echo "❌ CRITICAL ERROR: Could not locate 'base/logging.sh'." + exit 1 +fi + +# ============================================================================== +# CPU THREAD LIMITS (Crucial for concurrency) +# ============================================================================== +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export OPENBLAS_NUM_THREADS=1 +export VECLIB_MAXIMUM_THREADS=1 +export NUMEXPR_NUM_THREADS=1 +# ============================================================================== +# SECTION 2: HARDWARE & CONCURRENCY (Auto-Detected) +# ============================================================================== + +# 2.1 Auto-detect GPUs and determine jobs-per-GPU from VRAM. +# Output format: "JOBS_PER_GPU gpu_id_0 gpu_id_1 ..." +# Thresholds: >= 80 GB -> 4 jobs, <= 30 GB -> 2 jobs, between -> 3 jobs. +_gpu_info=$(python3 -c " +import subprocess +import os + +# 1. Read the allowed GPUs from the environment variable +selected_env = os.environ.get('SELECTED_GPUS', '').strip() +allowed_gpus = [x.strip() for x in selected_env.split(',')] if selected_env else None + +try: + out = subprocess.check_output( + ['nvidia-smi', '--query-gpu=index,memory.total', '--format=csv,noheader,nounits'], + text=True + ) + indices, mem_mb = [], [] + for line in out.strip().splitlines(): + idx, mem = line.split(',') + idx = idx.strip() + + # 2. 
Skip this GPU if it's not in our selected list + if allowed_gpus and idx not in allowed_gpus: + continue + + indices.append(idx) + mem_mb.append(int(mem.strip())) + + # Safety check in case the selected GPUs don't exist + if not indices: + print('0') + exit(0) + + min_mem_gb = min(mem_mb) / 1024 + if min_mem_gb >= 80: + jobs = 4 + elif min_mem_gb <= 10: + jobs = 1 + elif min_mem_gb <= 30: + jobs = 2 + else: + jobs = 3 + + print(jobs, ' '.join(indices)) +except Exception: + print('2 0') +") +read -r JOBS_PER_GPU _gpu_ids <<< "$_gpu_info" +read -ra physical_gpus <<< "$_gpu_ids" + +echo "✔ Detected ${#physical_gpus[@]} GPU(s): ${physical_gpus[*]}" +echo "✔ Jobs per GPU: $JOBS_PER_GPU" + +# 2.2 Create Virtual Slots +gpus=() +for gpu in "${physical_gpus[@]}"; do + for ((i=1; i<=JOBS_PER_GPU; i++)); do gpus+=("$gpu"); done +done +echo "✔ Total virtual slots: ${#gpus[@]}" + +# 2.3 Initialize Slot Tracking +declare -a slot_pids +for i in "${!gpus[@]}"; do slot_pids[$i]=0; done + + +# ============================================================================== +# SECTION 3: EXPERIMENT PARAMETERS +# ============================================================================== + +# --- Models (both domains) --- +# Use "alias::hydra_value" to disambiguate run names (both share basename "topotune"). +models=( + "sim_topotune::simplicial/topotune" + "cell_topotune::cell/topotune" +) + +# --- Datasets (same as hopse_m) --- +datasets=( + # "graph/MUTAG" + # "graph/cocitation_cora" + # "graph/PROTEINS" + # "graph/NCI1" + # "graph/NCI109" + # "graph/ZINC" + # "graph/cocitation_citeseer" + # "graph/cocitation_pubmed" + "simplicial/mantra_name" + "simplicial/mantra_orientation" + "simplicial/mantra_betti_numbers" +) + +# --- Neighborhoods (same as hopse_m) --- +# Use "alias::hydra_value" format for readable run names. +neighborhoods=( + "adj1::[up_adjacency-0]" + "adj2::[up_adjacency-0,2-up_adjacency-0]" + "adj3::[up_adjacency-0,up_adjacency-1,2-up_adjacency-0,down_adjacency-1,down_adjacency-2,2-down_adjacency-2]" + "inc1::[up_incidence-0,2-up_incidence-0]" + "inc2::[up_incidence-0,up_incidence-1,2-up_incidence-0,down_incidence-1,down_incidence-2,2-down_incidence-2]" +) + +# --- Hyperparameters (same as hopse_m) --- +gnn_num_layers=(1 2 4) +hidden_channels=(128 256) +proj_dropouts=(0.25 0.5) +lrs=(0.01 0.001) +weight_decays=(0 0.0001) +batch_sizes=(128 256) +DATA_SEEDS=(0 3 5 7 9) + +# --- Fixed Parameters --- +FIXED_ARGS=( + "trainer.max_epochs=500" + "trainer.min_epochs=50" + "trainer.check_val_every_n_epoch=5" + "callbacks.early_stopping.patience=10" + "delete_checkpoint_after_test=True" +) + + +# ============================================================================== +# SECTION 4: SWEEP CONFIGURATION MAPPING (CRITICAL ORDERING) +# Format: "ShortTag | HydraKey | ${Array[*]}" +# +# Values support an optional "alias::hydra_value" syntax for readable names. +# The generator also filters out invalid model+dataset combos. 
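+#
+# Example (illustrative): "N|model.backbone.neighborhoods|${neighborhoods[*]}" with the
+# value "adj1::[up_adjacency-0]" emits model.backbone.neighborhoods=[up_adjacency-0]
+# and contributes the token "Nadj1" to the run name.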
+# +# Key differences from hopse_m: +# - Neighborhoods: model.backbone.neighborhoods (not model.preprocessing_params.neighborhoods) +# - GNN layers: model.backbone.GNN.num_layers (not model.backbone.n_layers) +# - No encodings dimension (TopoTune does not use HOPSE PSE/FE encodings) +# ============================================================================== + +SWEEP_CONFIG=( + # --- LEVEL 1: SLOWEST CHANGING (Outer Loops) --- + "|model|${models[*]}" + "|dataset|${datasets[*]}" + "N|model.backbone.neighborhoods|${neighborhoods[*]}" + + # --- LEVEL 2: HYPERPARAMETERS --- + "L|model.backbone.GNN.num_layers|${gnn_num_layers[*]}" + "h|model.feature_encoder.out_channels|${hidden_channels[*]}" + "pdro|model.feature_encoder.proj_dropout|${proj_dropouts[*]}" + "lr|optimizer.parameters.lr|${lrs[*]}" + "wd|optimizer.parameters.weight_decay|${weight_decays[*]}" + "bs|dataset.dataloader_params.batch_size|${batch_sizes[*]}" + + # --- LEVEL 3: FASTEST CHANGING (Inner Loop) --- + "seed|dataset.split_params.data_seed|${DATA_SEEDS[*]}" +) + + +# ============================================================================== +# SECTION 5: PYTHON GENERATOR (Smart Transductive Filtering) +# ============================================================================== + +# Define where your dataset YAMLs live so the generator can inspect them. +# UPDATE THIS PATH IF YOUR CONFIGS ARE STORED ELSEWHERE. +export CONFIG_DIR="./configs/dataset" + +generate_combinations() { +python3 -c " +import sys, itertools, os + +config_dir = os.environ.get('CONFIG_DIR', './configs/dataset') + +# 1. Parse Input Specs +specs = [] +for item in sys.argv[1:]: + parts = item.split('|') + tag = parts[0].strip() + key = parts[1].strip() + vals = parts[2].split() + specs.append({'tag': tag, 'key': key, 'vals': vals}) + +# 2. Generate Cartesian Product +options = [[(s['tag'], s['key'], val) for val in s['vals']] for s in specs] +combinations = list(itertools.product(*options)) + +# Helper to strip alias +def hydra_val(v): + return v.split('::', 1)[1] if '::' in v else v + +# Find the first batch size in the sweep so we don't duplicate transductive runs +bs_key = 'dataset.dataloader_params.batch_size' +bs_spec = next((s for s in specs if s['key'] == bs_key), None) +first_bs = hydra_val(bs_spec['vals'][0]) if bs_spec else None + +# 3. 
Filter and Mutate Combos +valid = [] +skipped = 0 +transductive_cache = {} + +for combo in combinations: + vals_dict = {key: hydra_val(val) for (_, key, val) in combo} + model_val = vals_dict.get('model', '') + dataset_val = vals_dict.get('dataset', '') + current_bs = vals_dict.get(bs_key, '') + + # --- Rule A: Skip cell model + simplicial dataset --- + if model_val.startswith('cell/') and dataset_val.startswith('simplicial/'): + skipped += 1 + continue + + # --- Rule B: Transductive Batch Size Handler --- + is_transductive = False + if dataset_val in transductive_cache: + is_transductive = transductive_cache[dataset_val] + else: + # Construct path to yaml (e.g., ./configs/dataset/graph/cocitation_cora.yaml) + yaml_path = os.path.join(config_dir, f'{dataset_val}.yaml') + if os.path.exists(yaml_path): + with open(yaml_path, 'r') as f: + # Fast text check avoids needing pip install pyyaml + if 'learning_setting: transductive' in f.read(): + is_transductive = True + else: + print(f'⚠️ WARNING: Could not find config at {yaml_path}', file=sys.stderr) + + transductive_cache[dataset_val] = is_transductive + + if is_transductive: + # If this isn't the first batch size in the sweep list, skip it + # to avoid running the exact same bs=1 experiment multiple times. + if current_bs != first_bs: + skipped += 1 + continue + + # Mutate the current combination to force batch_size to 1 + new_combo = [] + for (tag, key, val) in combo: + if key == bs_key: + # Force the value to 1. If an alias was used, keep it clean. + new_combo.append((tag, key, '1')) + else: + new_combo.append((tag, key, val)) + combo = tuple(new_combo) + + valid.append(combo) + +# 4. Print header +print(f'TOTAL;{len(valid)}') +if skipped: + print(f'SKIPPED;{skipped}', file=sys.stderr) + +# 5. Print each valid combination +for combo in valid: + name_parts = [] + cmd_args = [] + for (tag, key, val) in combo: + if '::' in val: + alias, hydra_val_str = val.split('::', 1) + clean_val = alias + actual_val = hydra_val_str + else: + clean_val = os.path.basename(val) + actual_val = val + + if tag: + name_parts.append(f'{tag}{clean_val}') + else: + name_parts.append(clean_val) + cmd_args.append(f'{key}={actual_val}') + + run_name = '_'.join(name_parts) + print(f'{run_name};' + ' '.join(cmd_args)) +" "${SWEEP_CONFIG[@]}" +} + +# ============================================================================== +# SECTION 5.5: RESUME — LOAD COMPLETED RUNS +# ============================================================================== + +declare -A _completed_runs +if [[ "$RESUME" == "true" ]]; then + # run_and_log nests: $LOG_DIR/$log_group/SUCCESSFUL_RUNS.log + _success_log="$LOG_DIR/$log_group/SUCCESSFUL_RUNS.log" + if [[ -f "$_success_log" ]]; then + while IFS= read -r _line; do + # Format: "DATE: [SUCCESS] run_name" + _rname="${_line##*\[SUCCESS\] }" + _completed_runs["$_rname"]=1 + done < "$_success_log" + echo "✔ Loaded ${#_completed_runs[@]} completed runs to skip." + else + echo "⚠️ No SUCCESSFUL_RUNS.log found at $_success_log — nothing to skip." + fi +fi + +# ============================================================================== +# SECTION 6: MAIN EXECUTION LOOP +# ============================================================================== + +echo "----------------------------------------------------------" +echo " Generating experiment combinations..." 
+echo "----------------------------------------------------------" + +total_runs=0 +run_counter=0 +skipped_completed=0 +one_percent_step=1 + +while IFS=";" read -r col1 col2; do + + # 6.1 Handle Header + if [[ "$col1" == "TOTAL" ]]; then + total_runs=$col2 + if [ "$total_runs" -gt 0 ]; then + one_percent_step=$(( total_runs / 100 )) + fi + if [ "$one_percent_step" -eq 0 ]; then one_percent_step=1; fi + + echo "► Total runs planned: $total_runs" + echo "► Reporting progress every $one_percent_step runs (1%)" + echo "----------------------------------------------------------" + continue + fi + + # 6.2 Parse Run Data + run_name="$col1" + dynamic_args_str="$col2" + + # 6.2.1 Skip if already completed (RESUME mode) + if [[ "$RESUME" == "true" && -n "${_completed_runs[$run_name]+x}" ]]; then + ((skipped_completed++)) + continue + fi + + # 6.3 Update Progress + ((run_counter++)) + if (( run_counter % one_percent_step == 0 )); then + if [ "$total_runs" -gt 0 ]; then + percent=$(( (run_counter * 100) / total_runs )) + else + percent=0 + fi + echo "📊 Progress: ${percent}% completed ($run_counter / $total_runs runs launched)" + fi + + # 6.4 Find a Free GPU Slot + assigned_slot=-1 + while [ "$assigned_slot" -eq -1 ]; do + for i in "${!gpus[@]}"; do + pid="${slot_pids[$i]}" + if [ "$pid" -eq 0 ] || ! kill -0 "$pid" 2>/dev/null; then + assigned_slot=$i + break + fi + done + if [ "$assigned_slot" -eq -1 ]; then + wait -n + fi + done + + # 6.5 Prepare Command + current_gpu=${gpus[$assigned_slot]} + read -ra DYNAMIC_ARGS_ARRAY <<< "$dynamic_args_str" + + # --- Extract dataset name for dynamic W&B project --- + dataset_val="" + for arg in "${DYNAMIC_ARGS_ARRAY[@]}"; do + if [[ $arg == dataset=* ]]; then + dataset_full_path="${arg#*=}" + dataset_val=$(basename "$dataset_full_path") + break + fi + done + dynamic_project_name="${project_name}_${dataset_val}" + + cmd=( + "python" "-m" "topobench" + "${DYNAMIC_ARGS_ARRAY[@]}" + "${FIXED_ARGS[@]}" + "trainer.devices=[${current_gpu}]" + "+logger.wandb.entity=${wandb_entity}" + "logger.wandb.project=${dynamic_project_name}" + "+logger.wandb.name=${run_name}" + ) + + # 6.6 Execute + run_and_log "${cmd[*]}" "$log_group" "$run_name" "$LOG_DIR" & + slot_pids[$assigned_slot]=$! + +done < <(generate_combinations) + + +# ============================================================================== +# SECTION 7: CLEANUP +# ============================================================================== +echo "----------------------------------------------------------" +echo " All jobs launched ($run_counter total, $skipped_completed skipped as already completed)." +echo " Waiting for remaining background jobs to finish..." +echo "----------------------------------------------------------" +wait +echo "✔ All runs complete." 
diff --git a/test/utils/test_config_resolvers.py b/test/utils/test_config_resolvers.py index 67d024858..e2dddc65a 100644 --- a/test/utils/test_config_resolvers.py +++ b/test/utils/test_config_resolvers.py @@ -212,6 +212,20 @@ def test_infer_in_khop_feature_dim(self): out = infer_in_khop_feature_dim(dataset_in_channels, max_hop) assert out == [[7, 14, 42, 133], [7, 28, 91, 294], [7, 21, 70, 231]] + def test_infer_in_khop_feature_dim_with_complex_dim(self): + """Test infer_in_khop_feature_dim with complex_dim truncation.""" + # dataset_in_channels has 4 elements (from lifting complex_dim=3) + # but transform only processes 3 ranks (complex_dim=3) + dataset_in_channels = [7, 7, 7, 7] + max_hop = 2 + # Without truncation: rank 2 hop 1 = 28 (wrong, includes rank 3 neighbor) + out_no_trunc = infer_in_khop_feature_dim(dataset_in_channels, max_hop) + assert out_no_trunc[2][1] == 28 + # With truncation: rank 2 hop 1 = 21 (correct, no rank 3 neighbor) + out_trunc = infer_in_khop_feature_dim(dataset_in_channels, max_hop, complex_dim=3) + assert out_trunc[2][1] == 21 + assert len(out_trunc) == 3 + def test_check_pses_in_transforms_empty(self): """Test check_pses_in_transforms with no encodings.""" transforms = OmegaConf.create({}) @@ -789,7 +803,7 @@ def test_check_fes_keyed_sheaf(self): assert check_fes_in_transforms(transforms) == 9 def test_check_fes_combined_fes_pprfe_and_sheaf(self): - """CombinedFEs inner loop: PPRFE list + SheafConnLapPE.""" + """Test CombinedFEs inner loop with PPRFE list and SheafConnLapPE.""" transforms = OmegaConf.create( { "CombinedFEs": { @@ -808,7 +822,7 @@ def test_check_fes_combined_fes_pprfe_and_sheaf(self): assert check_fes_in_transforms(transforms) == 7 + 4 def test_check_fes_combined_fes_pprfe_default_alpha(self): - """CombinedFEs PPRFE: missing alpha_param uses default [0.1, 10] -> 10.""" + """Test CombinedFEs PPRFE with missing alpha_param using default [0.1, 10].""" transforms = OmegaConf.create( { "CombinedFEs": { @@ -820,7 +834,7 @@ def test_check_fes_combined_fes_pprfe_default_alpha(self): assert check_fes_in_transforms(transforms) == 10 def test_check_fes_combined_fes_pprfe_scalar_alpha(self): - """CombinedFEs PPRFE: scalar alpha_param.""" + """Test CombinedFEs PPRFE with scalar alpha_param.""" transforms = OmegaConf.create( { "CombinedFEs": { @@ -834,42 +848,42 @@ def test_check_fes_combined_fes_pprfe_scalar_alpha(self): assert check_fes_in_transforms(transforms) == 11 def test_get_fes_dimensions_khopfe(self): - """get_fes_dimensions: KHopFE uses max_hop - 1.""" + """Test get_fes_dimensions with KHopFE using max_hop - 1.""" encodings = ["KHopFE"] parameters = {"KHopFE": {"max_hop": 5}} assert get_fes_dimensions(encodings, parameters) == [4] def test_get_fes_dimensions_pprfe_list_tuple(self): - """get_fes_dimensions: PPRFE alpha as tuple -> second element.""" + """Test get_fes_dimensions with PPRFE alpha as tuple returning second element.""" encodings = ["PPRFE"] parameters = {"PPRFE": {"alpha_param_PPRFE": (0.1, 6)}} assert get_fes_dimensions(encodings, parameters) == [6] def test_get_fes_dimensions_pprfe_omegaconf_list(self): - """get_fes_dimensions: PPRFE alpha as OmegaConf list.""" + """Test get_fes_dimensions with PPRFE alpha as OmegaConf list.""" parameters = OmegaConf.create( {"PPRFE": {"alpha_param_PPRFE": [0.1, 12]}} ) assert get_fes_dimensions(["PPRFE"], parameters) == [12] def test_get_fes_dimensions_pprfe_scalar(self): - """get_fes_dimensions: PPRFE scalar alpha.""" + """Test get_fes_dimensions with PPRFE scalar alpha.""" encodings = ["PPRFE"] 
parameters = {"PPRFE": {"alpha_param_PPRFE": 5}} assert get_fes_dimensions(encodings, parameters) == [5] def test_get_fes_dimensions_pprfe_missing_uses_default(self): - """get_fes_dimensions: missing PPRFE block uses default alpha upper bound 10.""" + """Test get_fes_dimensions with missing PPRFE block using default alpha upper bound 10.""" assert get_fes_dimensions(["PPRFE"], {}) == [10] def test_get_fes_dimensions_sheaf(self): - """get_fes_dimensions: SheafConnLapPE.""" + """Test get_fes_dimensions with SheafConnLapPE.""" encodings = ["SheafConnLapPE"] parameters = {"SheafConnLapPE": {"max_pe_dim": 8}} assert get_fes_dimensions(encodings, parameters) == [8] def test_get_all_encoding_dimensions_khopfe_pprfe_sheaf(self): - """get_all_encoding_dimensions: FE branches KHopFE, PPRFE list, SheafConnLapPE.""" + """Test get_all_encoding_dimensions with KHopFE, PPRFE list, and SheafConnLapPE branches.""" encodings = ["KHopFE", "PPRFE", "SheafConnLapPE"] parameters = { "KHopFE": {"max_hop": 4}, @@ -879,11 +893,11 @@ def test_get_all_encoding_dimensions_khopfe_pprfe_sheaf(self): assert get_all_encoding_dimensions(encodings, parameters) == [3, 9, 3] def test_get_all_encoding_dimensions_pprfe_scalar(self): - """get_all_encoding_dimensions: PPRFE scalar alpha.""" + """Test get_all_encoding_dimensions with PPRFE scalar alpha.""" assert get_all_encoding_dimensions( ["PPRFE"], {"PPRFE": {"alpha_param_PPRFE": 2}} ) == [2] def test_get_all_encoding_dimensions_pprfe_missing_uses_default(self): - """get_all_encoding_dimensions: PPRFE absent from parameters -> default 10.""" + """Test get_all_encoding_dimensions with PPRFE absent from parameters using default 10.""" assert get_all_encoding_dimensions(["PPRFE"], {}) == [10] diff --git a/topobench/data/loaders/graph/adme_datasets.py b/topobench/data/loaders/graph/adme_datasets.py new file mode 100644 index 000000000..ed9dbf7c8 --- /dev/null +++ b/topobench/data/loaders/graph/adme_datasets.py @@ -0,0 +1,220 @@ +"""Loaders for TDC (Therapeutics Data Commons) ADME datasets with SMILES to graph conversion. +""" + +import os +from pathlib import Path + +import torch +from ogb.utils import smiles2graph +from omegaconf import DictConfig +from tdc.single_pred import ADME +from torch_geometric.data import Data, InMemoryDataset + +from topobench.data.loaders.base import AbstractLoader + + +class ADMEDatasetLoader(AbstractLoader): + """Load TDC ADME datasets with SMILES to graph conversion using OGB featurization. + + This loader: + 1. Loads ADME datasets from TDC (Therapeutics Data Commons) + 2. Converts SMILES strings to PyG graphs using OGB's standard featurization + 3. Uses fixed scaffold splits from TDC + 4. Returns graphs compatible with OGB molecular property prediction + + Node features (9-dimensional): + - Atomic number + - Chirality + - Degree + - Formal charge + - Number of hydrogens + - Number of radical electrons + - Hybridization + - Is aromatic + - Is in ring + + Edge features (3-dimensional): + - Bond type + - Bond stereochemistry + - Is conjugated + + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the ADME dataset + - data_type: Type of the dataset (e.g., "ADME") + """ + + def __init__(self, parameters: DictConfig) -> None: + super().__init__(parameters) + + def load_dataset(self) -> InMemoryDataset: + """Load the ADME dataset with predefined scaffold splits. 
+ + Returns + ------- + InMemoryDataset + The dataset with converted graphs and predefined splits. + + Raises + ------ + RuntimeError + If dataset loading or SMILES conversion fails. + ValueError + If invalid SMILES strings are encountered. + """ + + class _ADMEDataset(InMemoryDataset): + """Internal InMemoryDataset for ADME data.""" + + def __init__(self, root, data_name, split_idx, graph_list): + self.data_name = data_name + self.split_idx = split_idx + self._graph_list = graph_list + super().__init__(root) + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def processed_file_names(self): + return [f"{self.data_name}.pt"] + + def process(self): + self.data, self.slices = self.collate(self._graph_list) + torch.save((self.data, self.slices), self.processed_paths[0]) + + def __repr__(self): + return f"ADMEDataset({self.data_name}, {len(self)})" + # Define which datasets are classification vs regression + CLASSIFICATION_DATASETS = { + # Absorption + "PAMPA_NCATS", + "HIA_Hou", + "Pgp_Broccatelli", + "Bioavailability_Ma", + # Distribution + "BBB_Martins", + # Metabolism - CYP Inhibition + "CYP1A2_Veith", + "CYP2C9_Veith", + "CYP2C19_Veith", + "CYP2D6_Veith", + "CYP3A4_Veith", + # Metabolism - CYP Substrate + "CYP2C9_Substrate_CarbonMangels", + "CYP2D6_Substrate_CarbonMangels", + "CYP3A4_Substrate_CarbonMangels", + } + + REGRESSION_DATASETS = { + # Absorption + "Caco2_Wang", + "Lipophilicity_AstraZeneca", + "Solubility_AqSolDB", + "HydrationFreeEnergy_FreeSolv", + # Distribution + "PPBR_AZ", + "VDss_Lombardo", + # Excretion + "Half_Life_Obach", + "Clearance_Hepatocyte_AZ", + "Clearance_Microsome_AZ", + } + + # Determine task type + dataset_name = self.parameters.data_name + if dataset_name in CLASSIFICATION_DATASETS: + is_classification = True + elif dataset_name in REGRESSION_DATASETS: + is_classification = False + else: + raise ValueError( + f"Unknown ADME dataset: {dataset_name}. " + f"Please add it to CLASSIFICATION_DATASETS or REGRESSION_DATASETS." 
+ ) + + # Create raw data directory for TDC to download to + raw_dir = os.path.join(self.root_data_dir, dataset_name, "raw") + os.makedirs(raw_dir, exist_ok=True) + + # Load data from TDC with scaffold split, specify path for downloads + data = ADME(name=dataset_name, path=raw_dir) + split = data.get_split() + + # Convert splits to graphs + graph_list = [] + train_data = split["train"] + valid_data = split["valid"] + test_data = split["test"] + + # Process each split + for split_data in [train_data, valid_data, test_data]: + for _, row in split_data.iterrows(): + smiles = row["Drug"] + label = row["Y"] + + # Convert SMILES to graph using OGB's standard featurization + graph_dict = smiles2graph(smiles) + + # Create PyG Data object + if is_classification: + label_tensor = torch.tensor( + int(label), dtype=torch.long + ) + else: + label_tensor = torch.tensor([label], dtype=torch.float) + + pyg_graph = Data( + x=torch.tensor( + graph_dict["node_feat"], dtype=torch.float + ), + edge_index=torch.tensor( + graph_dict["edge_index"], dtype=torch.long + ), + edge_attr=torch.tensor( + graph_dict["edge_feat"], dtype=torch.float + ), + y=label_tensor, + num_nodes=graph_dict["num_nodes"], + ) + + graph_list.append(pyg_graph) + + # Prepare split indices + split_idx = { + "train": torch.arange(len(train_data)), + "valid": torch.arange( + len(train_data), len(train_data) + len(valid_data) + ), + "test": torch.arange( + len(train_data) + len(valid_data), + len(train_data) + len(valid_data) + len(test_data), + ), + } + + # Create dataset - point to data/graph/ADME/{dataset_name} + dataset_root = os.path.join(self.root_data_dir, dataset_name) + dataset = _ADMEDataset( + root=dataset_root, + data_name=self.parameters.data_name, + split_idx=split_idx, + graph_list=graph_list, + ) + + # Attach split_idx to the dataset for compatibility with framework + dataset.split_idx = split_idx + + return dataset + + def get_data_dir(self) -> Path: + """Get the data directory. + + Returns + ------- + Path + The path to the dataset directory. + Format: {root_data_dir}/{dataset_name}/ + Example: data/graph/ADME/BBB_Martins/ + """ + return os.path.join(self.root_data_dir, self.parameters.data_name) diff --git a/topobench/data/preprocessor/preprocessor.py b/topobench/data/preprocessor/preprocessor.py index c59cbd529..529ce8f2e 100644 --- a/topobench/data/preprocessor/preprocessor.py +++ b/topobench/data/preprocessor/preprocessor.py @@ -6,6 +6,7 @@ import torch import torch_geometric +from filelock import FileLock from torch_geometric.io import fs from tqdm import tqdm @@ -42,17 +43,28 @@ def __init__(self, dataset, data_dir, transforms_config=None, **kwargs): pre_transform = self.instantiate_pre_transform( data_dir, transforms_config ) - # Record the time taken for preprocessing - start_time = time.time() - super().__init__( - self.processed_data_dir, None, pre_transform, **kwargs + + # 1. Ensure the target directory exists so we can place a lock file in it + os.makedirs(self.processed_data_dir, exist_ok=True) + lock_path = os.path.join( + self.processed_data_dir, "preprocessing.lock" ) + + start_time = time.time() + + with FileLock(lock_path): + # When Process 1 finishes, Process 2 checks, sees data.pt, and skips. 
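+            # The lock serializes concurrent runs that share the same processed_data_dir
+            # (e.g. parallel sweep jobs on one dataset): only the first process pays the
+            # preprocessing cost; the others block here, then find data.pt and reuse it.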
+ super().__init__( + self.processed_data_dir, None, pre_transform, **kwargs + ) + self.save_transform_parameters() + end_time = time.time() self.preprocessing_time = end_time - start_time + self.transform = ( dataset.transform if hasattr(dataset, "transform") else None ) - self.save_transform_parameters() self.load(self.processed_paths[0]) self.data_list = [data for data in self] else: @@ -110,26 +122,53 @@ def instantiate_pre_transform( torch_geometric.transforms.Compose Pre-transform object. """ + from torch_geometric.transforms import ToDevice + if transforms_config.keys() == {"liftings"}: transforms_config = transforms_config.liftings - # Check if this is a single transform config (has transform_name key) - # or multiple transforms config (each value is a dict with transform_name) + if "transform_name" in transforms_config: - # Single transform configuration - pre_transforms_dict = { - transforms_config.transform_name: DataTransform( - **transforms_config - ) - } + config_items = [ + (transforms_config.transform_name, transforms_config) + ] else: - # Multiple transforms configuration - pre_transforms_dict = { - key: DataTransform(**value) - for key, value in transforms_config.items() - } + config_items = transforms_config.items() + + pre_transforms_list = [] + pre_transforms_dict = {} + + # Track where the graph currently lives in the pipeline + current_device = "cpu" + + for key, value in config_items: + kwargs = dict(value) + + requested_device = kwargs.pop("preprocessor_device", "cpu") + + target_device = ( + "cuda" + if requested_device == "cuda" and torch.cuda.is_available() + else "cpu" + ) + + transform_instance = DataTransform(**kwargs) + pre_transforms_dict[key] = transform_instance + + if target_device != current_device: + pre_transforms_list.append(ToDevice(target_device)) + current_device = target_device + + pre_transforms_list.append(transform_instance) + + # If the pipeline ends while the graph is still on the GPU, + # we MUST pull it back to the CPU before PyTorch Geometric saves it to disk. + if current_device == "cuda": + pre_transforms_list.append(ToDevice("cpu")) + pre_transforms = torch_geometric.transforms.Compose( - list(pre_transforms_dict.values()) + pre_transforms_list ) + self.set_processed_data_dir( pre_transforms_dict, data_dir, transforms_config ) @@ -198,7 +237,9 @@ def process(self) -> None: print(f"\nApplying transforms to {len(data_list)} graphs...") self.data_list = [ self.pre_transform(d) - for d in tqdm(data_list, desc="Processing graphs", unit="graph") + for d in tqdm( + data_list, desc="Processing graphs", unit="graph" + ) ] else: self.data_list = data_list diff --git a/topobench/run.py b/topobench/run.py index 37c3de291..d9a217903 100755 --- a/topobench/run.py +++ b/topobench/run.py @@ -1,5 +1,6 @@ """Main entry point for training and testing models.""" +import os import random from pathlib import Path from typing import Any @@ -166,6 +167,7 @@ infer_list_length_plus_one, replace=True, ) +OmegaConf.register_new_resolver("pid", lambda: os.getpid()) def initialize_hydra() -> DictConfig: @@ -409,6 +411,19 @@ def rerun_best_model_checkpoint( if isinstance(lgr, WandbLogger): lgr.log_metrics(logged) + if ( + cfg.get("delete_checkpoint_after_test", False) + and model_path + and model_path.exists() + ): + log.info(f"Cleaning up: Deleting checkpoint at {model_path}") + try: + model_path.unlink() + except Exception as e: + log.warning( + f"Failed to delete checkpoint at {model_path}. 
Error: {e}" + ) + def count_number_of_parameters( model: torch.nn.Module, only_trainable: bool = True diff --git a/topobench/transforms/data_manipulations/combined_positional_and_structural_encodings.py b/topobench/transforms/data_manipulations/combined_positional_and_structural_encodings.py index 6eb2fba42..75f368b4a 100644 --- a/topobench/transforms/data_manipulations/combined_positional_and_structural_encodings.py +++ b/topobench/transforms/data_manipulations/combined_positional_and_structural_encodings.py @@ -1,5 +1,6 @@ """Combined Positional and Structural Encodings Transform.""" +import torch from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform @@ -19,10 +20,13 @@ class CombinedPSEs(BaseTransform): ---------- encodings : list of str List of structural encodings to apply. Supported values are - "LapPE" for Laplacian Positional Encoding and "RWSE" for - Random Walk Structural Encoding. + "LapPE", "RWSE", "ElectrostaticPE", and "HKdiagSE". parameters : dict, optional Additional parameters for the encoding transforms. + preprocessor_device : str, optional + The overarching device to use for the combined transforms (e.g., 'cpu', 'cuda'). + If a specific encoding specifies its own device in `parameters`, that will + take precedence. Default is None. **kwargs : dict, optional Additional keyword arguments. """ @@ -31,10 +35,12 @@ def __init__( self, encodings: list[str], parameters: dict | None = None, + preprocessor_device: str | None = None, **kwargs, ): self.encodings = encodings self.parameters = parameters if parameters is not None else {} + self.device = preprocessor_device def forward(self, data: Data) -> Data: r"""Apply the transform to the input data. @@ -73,13 +79,45 @@ def forward(self, data: Data) -> Data: f"Missing in PSE_ENCODINGS: {missing_in_set}." ) + if hasattr(data, "edge_index") and data.edge_index is not None: + baseline_device = data.edge_index.device + elif hasattr(data, "x") and data.x is not None: + baseline_device = data.x.device + else: + baseline_device = torch.device("cpu") + + current_device = baseline_device + for enc in self.encodings: if enc not in encoding_classes: raise ValueError(f"Unsupported encoding type: {enc}") - encoder = encoding_classes[enc](**self.parameters.get(enc, {})) + enc_params = self.parameters.get(enc, {}).copy() + + # Determine the target device for this specific transform + # Priority: 1. PE-specific device, 2. CombinedPSEs overarching device, 3. Baseline + req_device = enc_params.pop("device", self.device) + target_device = ( + torch.device(req_device) if req_device else baseline_device + ) + + # Fallback to CPU if CUDA is requested but physically unavailable + if target_device.type == "cuda" and not torch.cuda.is_available(): + target_device = torch.device("cpu") + + if current_device != target_device: + data = data.to(target_device) + current_device = target_device + + # Instantiate and apply the encoder + # The encoder naturally uses `current_device` because it reads `data.edge_index.device` + encoder = encoding_classes[enc](**enc_params) data = encoder(data) + # Ensure the graph is returned to its original device before exiting + if current_device != baseline_device: + data = data.to(baseline_device) + return data @@ -92,8 +130,8 @@ class SelectDestinationPSEs(BaseTransform): Parameters ---------- - encoding_key : str - Key in `data` where the PSEs are stored (e.g., 'LapPE', 'RWSE'). + encodings : list of str + Keys in `data` where the PSEs are stored (e.g., ['LapPE', 'RWSE']). 
**kwargs : dict, optional Additional keyword arguments. """ @@ -130,4 +168,18 @@ def forward(self, data: Data, n_dst_nodes: int) -> Data: return Data(**new_data) def __call__(self, data: Data, n_dst_nodes: int) -> Data: + """Override __call__ to accept n_dst_nodes as an argument. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + n_dst_nodes : int + Number of destination nodes. + + Returns + ------- + torch_geometric.data.Data + The transformed data with selected PSEs. + """ return self.forward(data, n_dst_nodes) diff --git a/topobench/transforms/data_manipulations/electrostatic_encodings.py b/topobench/transforms/data_manipulations/electrostatic_encodings.py index 6bfd8bc25..02c289ba7 100644 --- a/topobench/transforms/data_manipulations/electrostatic_encodings.py +++ b/topobench/transforms/data_manipulations/electrostatic_encodings.py @@ -1,15 +1,13 @@ -"""Laplacian Positional Encoding (LapPE) Transform.""" +"""Electrostatic Positional Encoding (ElectrostaticPE) Transform.""" -from copy import deepcopy +import time -import numpy as np import torch from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform from torch_geometric.utils import ( get_laplacian, remove_self_loops, - to_scipy_sparse_matrix, ) @@ -17,15 +15,20 @@ class ElectrostaticPE(BaseTransform): r""" Electrostatic Positional Encoding (ElectrostaticPE) transform. - Kernel based on the electrostatic interaction between nodes. - Parameters ---------- concat_to_x : bool, optional - If True, concatenates the encodings with existing node features in - ``data.x``. If ``data.x`` is None, creates it. Default is True. + If True, concatenates the encodings with existing node features. + Default is True. eps : float, optional - Small value to avoid division by zero. Default is 1e-6. + Small value to avoid division by zero. + Default is 1e-6. + method : str, optional + Computation method: "numpy" (CPU NumPy) or "gpu" (GPU PyTorch). + Default is "gpu". + debug : bool, optional + If True, runs both methods and compares outputs. + Default is False. **kwargs : dict Additional arguments (not used). """ @@ -34,26 +37,60 @@ def __init__( self, concat_to_x: bool = True, eps: float = 1e-6, + method: str = "numpy", + debug: bool = False, **kwargs, ): self.concat_to_x = concat_to_x self.eps = eps + self.method = method + self.debug = debug self.pe_dim = 7 + if method not in ["numpy", "gpu"]: + raise ValueError("Method must be 'numpy' or 'gpu'.") + def forward(self, data: Data) -> Data: - """Compute the Laplacian positional encodings for the input graph. + """Compute the electrostatic positional encodings for the input graph. Parameters ---------- - data : Data + data : torch_geometric.data.Data Input graph data object. Returns ------- - Data - Graph data object with Laplacian positional encodings added. + torch_geometric.data.Data + Graph data object with electrostatic positional encodings added. 
""" - pe = self._compute_electrostatic_pe(data.edge_index, data.num_nodes) + if self.debug: + print("\n--- ElectrostaticPE Debug Report ---") + print(f"Data device: {data.edge_index.device}") + # Exact Method (Original CPU NumPy) + t0 = time.time() + pe_numpy = self._compute_numpy(data.edge_index, data.num_nodes) + t_numpy = time.time() - t0 + print(f"Exact compute time: {t_numpy:.4f}s") + + # Fast Method (Pure PyTorch GPU) + t0 = time.time() + pe_gpu = self._compute_gpu(data.edge_index, data.num_nodes) + t_gpu = time.time() - t0 + print(f"Fast compute time: {t_gpu:.4f}s") + + # Compare (Only if non-zero) + diff = torch.abs(pe_numpy - pe_gpu) + speedup = (t_numpy / t_gpu) if t_gpu > 0 else float("inf") + print(f"Speedup Factor: {speedup:.2f}x") + print(f"Mean Abs Error: {diff.mean().item():.6e}") + print("------------------------------------\n") + + pe = pe_numpy if self.method == "numpy" else pe_gpu + else: + if self.method == "numpy": + pe = self._compute_numpy(data.edge_index, data.num_nodes) + else: + pe = self._compute_gpu(data.edge_index, data.num_nodes) if self.concat_to_x: if data.x is None: @@ -65,10 +102,10 @@ def forward(self, data: Data) -> Data: return data - def _compute_electrostatic_pe( + def _compute_gpu( self, edge_index: torch.Tensor, num_nodes: int ) -> torch.Tensor: - """Internal method to compute electrostatic positional encodings. + """Compute ElectrostaticPE using optimized pure-PyTorch implementation. Parameters ---------- @@ -80,14 +117,106 @@ def _compute_electrostatic_pe( Returns ------- torch.Tensor - Electrostatic positional encodings. + Electrostatic positional encodings of shape ``[num_nodes, 7]``. """ device = edge_index.device + if edge_index.size(1) == 0 or num_nodes <= 1: + return torch.zeros(num_nodes, self.pe_dim, device=device) + + # 1. Get Normalized Laplacian and make it dense immediately on device + edge_index_lap, edge_weight = get_laplacian( + edge_index, normalization="sym", num_nodes=num_nodes + ) + L = torch.sparse_coo_tensor( + edge_index_lap, + edge_weight.float(), + (num_nodes, num_nodes), + device=device, + ).to_dense() + + # 2. Efficiently compute DinvA without deepcopy or torch.eye + diag_L = L.diagonal() + Dinv_vec = 1.0 / (diag_L + 1e-6) + A = L.abs() + A.fill_diagonal_(0) + # Broadcasting [N, 1] * [N, N] applies the row-wise scalar multiplication identical to Dinv @ A + DinvA = Dinv_vec.unsqueeze(1) * A + + # 3. Hardware-accelerated eigendecomposition + evals, evecs = torch.linalg.eigh(L) + + # 4. Filter eigenvalues + mask = evals >= self.eps + if not mask.any(): + return torch.zeros( + num_nodes, self.pe_dim, dtype=torch.float32, device=device + ) + + evals_filtered = evals[mask] + evecs_filtered = evecs[:, mask] + + # 5. Reconstruct Pseudo-Inverse (Electrostatic matrix) + # evecs @ diag(1/evals) @ evecs.T + electrostatic = (evecs_filtered / evals_filtered) @ evecs_filtered.T + + # Broadcast subtraction of the diagonal + electrostatic = electrostatic - electrostatic.diag() + + # 6. 
Compute statistics + # Note: dim=0 is operations along columns, dim=1 is operations along rows + electrostatic_encoding = torch.stack( + [ + electrostatic.min(dim=0)[0], + electrostatic.mean(dim=0), + electrostatic.std(dim=0), + electrostatic.min(dim=1)[0], + electrostatic.std(dim=1), + (DinvA * electrostatic).sum(dim=0), + (DinvA * electrostatic).sum(dim=1), + ], + dim=1, + ) + + # Corner case check + if ( + torch.all(electrostatic_encoding == 0) + and num_nodes > 2 + and list(remove_self_loops(edge_index)[0].cpu().shape) != [2, 0] + ): + raise ValueError("ElectrostaticPE is all zeros") + + if torch.any(torch.isnan(electrostatic_encoding)): + raise ValueError("ElectrostaticPE contains NaNs") + + return electrostatic_encoding + + def _compute_numpy( + self, edge_index: torch.Tensor, num_nodes: int + ) -> torch.Tensor: + """Compute ElectrostaticPE using the original CPU NumPy implementation. + + Parameters + ---------- + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + + Returns + ------- + torch.Tensor + Electrostatic positional encodings of shape ``[num_nodes, 7]``. + """ + from copy import deepcopy + + import numpy as np + from torch_geometric.utils import to_scipy_sparse_matrix + + device = edge_index.device if edge_index.size(1) == 0 or num_nodes <= 1: return torch.zeros(num_nodes, self.pe_dim, device=device) - # Normalized Laplacian edge_index_lap, edge_weight = get_laplacian( edge_index, normalization="sym", num_nodes=num_nodes ) @@ -103,13 +232,7 @@ def _compute_electrostatic_pe( A.fill_diagonal_(0) DinvA = Dinv.matmul(A) - # evals, evecs = torch.linalg.eigh(L) - # try: - # #evals, evecs = torch.linalg.eigh(L) - # except: - # IMDB-BINARY has some issue with scipy.sparse.linalg.eigsh deep in scipy library. 
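+        # np.linalg.eigh is kept here on purpose: scipy.sparse.linalg.eigsh proved
+        # unreliable on some graphs (e.g. IMDB-BINARY), as noted in the removed comments.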
evals, evecs = np.linalg.eigh(L.numpy()) - # back to torch evals = torch.from_numpy(evals) evecs = torch.from_numpy(evecs) @@ -123,29 +246,22 @@ def _compute_electrostatic_pe( electrostatic = electrostatic - electrostatic.diag() electrostatic_encoding = torch.stack( [ - electrostatic.min(dim=0)[0], # Min of Vi -> j - electrostatic.mean(dim=0), # Mean of Vi -> j - electrostatic.std(dim=0), # Std of Vi -> j - electrostatic.min(dim=1)[0], # Min of Vj -> i - electrostatic.std(dim=1), # Std of Vj -> i - (DinvA * electrostatic).sum( - dim=0 - ), # Mean of interaction on direct neighbour - (DinvA * electrostatic).sum( - dim=1 - ), # Mean of interaction from direct neighbour + electrostatic.min(dim=0)[0], + electrostatic.mean(dim=0), + electrostatic.std(dim=0), + electrostatic.min(dim=1)[0], + electrostatic.std(dim=1), + (DinvA * electrostatic).sum(dim=0), + (DinvA * electrostatic).sum(dim=1), ], dim=1, ) - # TODO: some corner case when num_nodes=2 on MUTAG - if torch.all(electrostatic_encoding == 0) and num_nodes > 2: - if list(remove_self_loops(edge_index)[0].cpu().shape) == [2, 0]: - # Case when there is no connectivity - pass - else: - raise ValueError("ElectrostaticPE is all zeros") + if ( + torch.all(electrostatic_encoding == 0) + and num_nodes > 2 + and list(remove_self_loops(edge_index)[0].cpu().shape) != [2, 0] + ): + raise ValueError("ElectrostaticPE is all zeros") - if torch.any(torch.isnan(electrostatic_encoding)): - raise ValueError("ElectrostaticPE contains NaNs") - return electrostatic_encoding.float() + return electrostatic_encoding.float().to(device) diff --git a/topobench/transforms/data_manipulations/hk_feature_encodings.py b/topobench/transforms/data_manipulations/hk_feature_encodings.py index da1187b60..764e246eb 100644 --- a/topobench/transforms/data_manipulations/hk_feature_encodings.py +++ b/topobench/transforms/data_manipulations/hk_feature_encodings.py @@ -1,9 +1,12 @@ -"""Heat Kernel feature Encoding (HKFE) Transform.""" +"""Heat Kernel feature Encoding (HKFE) Transform (Debug Version).""" + +import time import numpy as np import omegaconf import torch from scipy.sparse.linalg import expm_multiply +from scipy.special import iv from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform from torch_geometric.utils import ( @@ -21,11 +24,18 @@ class HKFE(BaseTransform): kernel_param_HKFE : tuple of int Tuple specifying the start and end diffusion times for the heat kernel. concat_to_x : bool, optional - If True, concatenates the encodings with existing node features in - ``data.x``. If ``data.x`` is None, creates it. Default is True. + If True, concatenates encodings with existing node features in ``data.x``. + Default is True. aggregation : str, optional Aggregation function to reduce over the feature dimension. Options: "mean", "sum", "max", "min". Default is "mean". + method : str, optional + Computation method: "exact" or "approx". Default is "approx". + cheb_order : int, optional + The order of the Chebyshev polynomial. Default is 10. + debug : bool, optional + If True, runs both exact and approx methods, compares their outputs, + and prints the timing and error metrics. Default is False. **kwargs : dict Additional arguments (not used). 
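+
+    Examples
+    --------
+    Illustrative usage only (hypothetical parameter values)::
+
+        transform = HKFE(kernel_param_HKFE=(1, 5), method="approx", cheb_order=10)
+        data = transform(data)  # requires data.x to be set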
""" @@ -37,17 +47,24 @@ def __init__( kernel_param_HKFE: tuple, concat_to_x: bool = True, aggregation: str = "mean", + method: str = "approx", + cheb_order: int = 10, + debug: bool = False, **kwargs, ): self.kernel_param_HKFE = kernel_param_HKFE self.concat_to_x = concat_to_x + if aggregation not in self._AGG_FN_MAP: - raise ValueError( - f"Unknown aggregation '{aggregation}'. " - f"Choose from: {list(self._AGG_FN_MAP.keys())}" - ) + raise ValueError(f"Unknown aggregation '{aggregation}'.") self.aggregation = aggregation - # Compute fe_dim from tuple/list or use directly if int + + if method not in ["exact", "approx"]: + raise ValueError("Method must be 'exact' or 'approx'.") + self.method = method + self.cheb_order = cheb_order + self.debug = debug + if ( isinstance(kernel_param_HKFE, (list, tuple)) or type(kernel_param_HKFE) is omegaconf.listconfig.ListConfig @@ -57,17 +74,17 @@ def __init__( self.fe_dim = kernel_param_HKFE def forward(self, data: Data) -> Data: - """Compute the Heat Kernel feature encodings for the input graph. + """Compute the HKFE for the input graph. Parameters ---------- - data : Data + data : torch_geometric.data.Data Input graph data object. Returns ------- - Data - Graph data object with heat kernel feature encodings added. + torch_geometric.data.Data + Graph data object with HKFE added to ``data.x`` or ``data.HKFE``. """ if data.x is None: raise ValueError( @@ -92,7 +109,7 @@ def _compute_hkfe( """Internal method to compute heat kernel feature encodings. Computes heat kernel diffusion at multiple time scales and aggregates - over input features to produce a fixed-dimension output (matching PSE pattern). + over input features to produce a fixed-dimension output. Parameters ---------- @@ -106,38 +123,213 @@ def _compute_hkfe( Returns ------- torch.Tensor - Heat Kernel feature encodings of shape [N, fe_dim]. + Heat Kernel feature encodings of shape ``[num_nodes, fe_dim]``. """ device = edge_index.device - hk_fe = [] if edge_index.size(1) == 0 or num_nodes <= 1: return torch.zeros(num_nodes, self.fe_dim, device=device) - # Normalized Laplacian + start, end = self.kernel_param_HKFE[0], self.kernel_param_HKFE[1] + kernel_times = np.geomspace(start, end, self.fe_dim) + + if len(kernel_times) == 0: + raise ValueError("Diffusion times are required for heat kernel") + edge_index_lap, edge_weight = get_laplacian( edge_index, normalization="sym", num_nodes=num_nodes ) + + if self.debug: + print("\n--- HKFE Debug Report ---") + print(f"Data device: {edge_index.device}") + # 1. Run Exact Method + t0 = time.time() + hk_fe_exact = self._compute_exact( + x, edge_index_lap, edge_weight, num_nodes, kernel_times, device + ) + t_exact = time.time() - t0 + print(f"Exact compute time: {t_exact:.4f}s") + + # 2. Run Approx Method + t0 = time.time() + hk_fe_approx = self._compute_approx( + x, edge_index_lap, edge_weight, num_nodes, kernel_times, device + ) + t_approx = time.time() - t0 + print( + f"Approx compute time: {t_approx:.4f}s (Cheb Order: {self.cheb_order})" + ) + + # 3. 
Compare Tensors (Before aggregation to see pure mathematical difference) + diff = torch.abs(hk_fe_exact - hk_fe_approx) + mean_diff = diff.mean().item() + max_diff = diff.max().item() + reldiff = diff / (torch.abs(hk_fe_exact) + 1e-8) + mean_reldiff = reldiff.mean().item() + + print(f"Speedup Factor: {t_exact / t_approx:.2f}x") + print(f"Mean Abs Error: {mean_diff:.6e}") + print(f"Max Abs Error: {max_diff:.6e}") + print(f"Mean Rel Error: {mean_reldiff:.6e}") + print("-------------------------\n") + + # Proceed with the method the user actually requested + hk_fe_raw = hk_fe_exact if self.method == "exact" else hk_fe_approx + else: + if self.method == "exact": + hk_fe_raw = self._compute_exact( + x, + edge_index_lap, + edge_weight, + num_nodes, + kernel_times, + device, + ) + else: + hk_fe_raw = self._compute_approx( + x, + edge_index_lap, + edge_weight, + num_nodes, + kernel_times, + device, + ) + + # Aggregate over features + agg_fn = getattr(hk_fe_raw, self._AGG_FN_MAP[self.aggregation]) + hk_fe = agg_fn(dim=-1) + + if torch.any(torch.isnan(hk_fe)): + raise ValueError("HKFE contains NaNs") + return hk_fe.float() + + def _compute_exact( + self, + x: torch.Tensor, + edge_index_lap: torch.Tensor, + edge_weight: torch.Tensor, + num_nodes: int, + kernel_times: np.ndarray, + device: torch.device, + ) -> torch.Tensor: + """Compute HKFE using original SciPy-based exact matrix exponential. + + Parameters + ---------- + x : torch.Tensor + Node features of the graph. + edge_index_lap : torch.Tensor + Laplacian edge indices. + edge_weight : torch.Tensor + Laplacian edge weights. + num_nodes : int + Number of nodes in the graph. + kernel_times : numpy.ndarray + Array of diffusion times. + device : torch.device + The device to perform computations on. + + Returns + ------- + torch.Tensor + Exact heat kernel feature encodings. + """ L = to_scipy_sparse_matrix( edge_index_lap, edge_weight, num_nodes ).astype(np.float64) + hk_fe = [] + x_np = x.detach().cpu().numpy().astype(np.float64) + for t in kernel_times: + x_t = expm_multiply((-float(t)) * L, x_np) + hk_fe.append(torch.from_numpy(x_t).float().to(device)) + return torch.stack(hk_fe, dim=1) + + def _compute_approx( + self, + x: torch.Tensor, + edge_index_lap: torch.Tensor, + edge_weight: torch.Tensor, + num_nodes: int, + kernel_times: np.ndarray, + device: torch.device, + ) -> torch.Tensor: + """Compute HKFE using fast Chebyshev polynomial approximation on GPU. + + Parameters + ---------- + x : torch.Tensor + Node features of the graph. + edge_index_lap : torch.Tensor + Laplacian edge indices. + edge_weight : torch.Tensor + Laplacian edge weights. + num_nodes : int + Number of nodes in the graph. + kernel_times : numpy.ndarray + Array of diffusion times. + device : torch.device + The device to perform computations on. + + Returns + ------- + torch.Tensor + Approximated heat kernel feature encodings. + """ + L = torch.sparse_coo_tensor( + edge_index_lap, + edge_weight.float(), + (num_nodes, num_nodes), + device=device, + ).coalesce() + + def apply_L_tilde(v): + """Apply the normalized Laplacian to a vector v. + + Parameters + ---------- + v : torch.Tensor + Input vector of shape [num_nodes, feature_dim]. + + Returns + ------- + torch.Tensor + Result of applying the normalized Laplacian to v. + """ + return torch.sparse.mm(L, v) - v + + T_x = [x] + if self.cheb_order > 0: + T_x.append(apply_L_tilde(x)) + + for _ in range(2, self.cheb_order + 1): + T_k = 2 * apply_L_tilde(T_x[-1]) - T_x[-2] + T_x.append(T_k) + + T_x = torch.stack(T_x, dim=0) + + # 1. 
Vectorize the CPU computation + # t_np shape: (T,) + t_np = kernel_times + # k_np shape: (K, 1) where K is cheb_order + 1 + k_np = np.arange(self.cheb_order + 1)[:, None] + + # SciPy's iv broadcasts automatically to shape (K, T) + bessel = iv(k_np, t_np[None, :]) - start, end = ( - self.kernel_param_HKFE[0], - self.kernel_param_HKFE[1], + # Calculate all coefficients at once: shape (K, T) + coeffs = 2 * np.exp(-t_np) * ((-1) ** k_np) * bessel + coeffs[0, :] /= 2 # The k=0 term doesn't have the 2x multiplier + + # 2. Single transfer to GPU + # Transpose to shape (T, K) for easier einsum matching + coeffs_tensor = torch.tensor( + coeffs.T, dtype=torch.float32, device=device ) - kernel_times = np.geomspace(start, end, self.fe_dim) - if len(kernel_times) == 0: - raise ValueError("Diffusion times are required for heat kernel") - x = x.detach().cpu().numpy().astype(np.float64) - for t in kernel_times: - x_t = expm_multiply((-float(t)) * L, x) - hk_fe.append(torch.from_numpy(x_t).float().to(device)) - hk_fe = torch.stack(hk_fe, dim=1) # [N, fe_dim, F] - # Aggregate over features to produce fixed-dimension output (like PSEs) - agg_fn = getattr(hk_fe, self._AGG_FN_MAP[self.aggregation]) - hk_fe = agg_fn(dim=-1) # [N, fe_dim] + # 3. Single operation on GPU! + # coeffs_tensor is [T, K] + # T_x is [K, N, F] + # We want output [N, T, F] to match your original torch.stack(..., dim=1) + hk_fe = torch.einsum("tk, knf -> ntf", coeffs_tensor, T_x) - if torch.any(torch.isnan(hk_fe)): - raise ValueError("HKFE contains NaNs") - return hk_fe.float() + return hk_fe diff --git a/topobench/transforms/data_manipulations/hkdiag_encodings.py b/topobench/transforms/data_manipulations/hkdiag_encodings.py index aff37e924..c2d04d1e1 100644 --- a/topobench/transforms/data_manipulations/hkdiag_encodings.py +++ b/topobench/transforms/data_manipulations/hkdiag_encodings.py @@ -1,6 +1,7 @@ -"""Laplacian Positional Encoding (LapPE) Transform.""" +"""Heat Kernel Diagonal Structural Encoding (HKdiagSE) Transform.""" + +import time -import numpy as np import omegaconf import torch import torch.nn.functional as F @@ -9,7 +10,6 @@ from torch_geometric.utils import ( get_laplacian, remove_self_loops, - to_scipy_sparse_matrix, ) @@ -17,21 +17,28 @@ class HKdiagSE(BaseTransform): r""" Heat Kernel Diagonal Structural Encoding (HKdiagSE) transform. - Diagonals of heat kernel diffusion. - Parameters ---------- kernel_param_HKdiagSE : tuple of int Tuple specifying the start and end diffusion times for the heat kernel. space_dim : int, optional - Estimated dimensionality of the space. Used to - correct the diffusion diagonal by a factor `t^(space_dim/2)`. In - euclidean space, this correction means that the height of the - gaussian stays constant across time, if `space_dim` is the dimension - of the euclidean space. Default is 0 (no correction). + Estimated dimensionality of the space. Used to correct the diffusion + diagonal by a factor `t^(space_dim/2)`. Default is 0 (no correction). + include_eigenvalues : bool, optional + If True, concatenates eigenvalues alongside eigenvectors. + Default is False. + include_first : bool, optional + If False, removes eigenvectors corresponding to (near-)zero eigenvalues. + Default is False. concat_to_x : bool, optional - If True, concatenates the encodings with existing node features in - ``data.x``. If ``data.x`` is None, creates it. Default is True. + If True, concatenates the encodings with existing node features. + Default is True. 
+ method : str, optional + Computation method: "exact" (CPU NumPy + loop) or "fast" (GPU PyTorch + vectorized). + Default is "fast". + debug : bool, optional + If True, runs both methods and prints error/timing metrics. + Default is False. **kwargs : dict Additional arguments (not used). """ @@ -43,6 +50,8 @@ def __init__( include_eigenvalues: bool = False, include_first: bool = False, concat_to_x: bool = True, + method: str = "fast", + debug: bool = False, **kwargs, ): self.kernel_param_HKdiagSE = kernel_param_HKdiagSE @@ -50,26 +59,58 @@ def __init__( self.include_eigenvalues = include_eigenvalues self.include_first = include_first self.concat_to_x = concat_to_x + self.method = method + self.debug = debug self.pe_dim = ( kernel_param_HKdiagSE[1] - kernel_param_HKdiagSE[0] if type(kernel_param_HKdiagSE) is omegaconf.listconfig.ListConfig else kernel_param_HKdiagSE ) + if method not in ["exact", "fast"]: + raise ValueError("Method must be 'exact' or 'fast'.") + def forward(self, data: Data) -> Data: - """Compute the Laplacian positional encodings for the input graph. + """Compute the Heat Kernel Diagonal Structural Encodings for the input graph. Parameters ---------- - data : Data + data : torch_geometric.data.Data Input graph data object. Returns ------- - Data - Graph data object with Laplacian positional encodings added. + torch_geometric.data.Data + Graph data object with HKdiagSE positional encodings added. """ - pe = self._compute_hkdiag_se(data.edge_index, data.num_nodes) + if self.debug: + print("\n--- HKdiagSE Debug Report ---") + print(f"Data device: {data.edge_index.device}") + # Exact Method (CPU) + t0 = time.time() + pe_exact = self._compute_exact(data.edge_index, data.num_nodes) + t_exact = time.time() - t0 + print(f"Exact compute time: {t_exact:.4f}s") + + # Fast Method (GPU Vectorized) + t0 = time.time() + pe_fast = self._compute_fast(data.edge_index, data.num_nodes) + t_fast = time.time() - t0 + print(f"Fast compute time: {t_fast:.4f}s") + + # Compare + diff = torch.abs(pe_exact - pe_fast) + speedup = (t_exact / t_fast) if t_fast > 0 else float("inf") + print(f"Speedup Factor: {speedup:.2f}x") + print(f"Mean Abs Error: {diff.mean().item():.6e}") + print("---------------------------\n") + + pe = pe_exact if self.method == "exact" else pe_fast + else: + if self.method == "exact": + pe = self._compute_exact(data.edge_index, data.num_nodes) + else: + pe = self._compute_fast(data.edge_index, data.num_nodes) if self.concat_to_x: if data.x is None: @@ -81,10 +122,10 @@ def forward(self, data: Data) -> Data: return data - def _compute_hkdiag_se( + def _compute_fast( self, edge_index: torch.Tensor, num_nodes: int ) -> torch.Tensor: - """Internal method to compute heat kernel diagonal structural encodings. + """Compute HKdiagSE using an optimized pure-PyTorch implementation with a vectorized time loop. Parameters ---------- @@ -96,14 +137,99 @@ def _compute_hkdiag_se( Returns ------- torch.Tensor - Electrostatic positional encodings. + Heat Kernel Diagonal Structural Encodings of shape ``[num_nodes, pe_dim]``. """ device = edge_index.device if edge_index.size(1) == 0 or num_nodes <= 1: return torch.zeros(num_nodes, self.pe_dim, device=device) - # Normalized Laplacian + # 1. Create dense Laplacian directly on GPU + edge_index_lap, edge_weight = get_laplacian( + edge_index, normalization="sym", num_nodes=num_nodes + ) + L = torch.sparse_coo_tensor( + edge_index_lap, + edge_weight.float(), + (num_nodes, num_nodes), + device=device, + ).to_dense() + + # 2. 
Hardware-accelerated eigendecomposition + evals, evects = torch.linalg.eigh(L) + evects = F.normalize(evects, p=2.0, dim=0) + + # 3. Filter out zero eigenvalues + mask = evals >= 1e-8 + evals = evals[mask] + evects = evects[:, mask] + + start, end = ( + self.kernel_param_HKdiagSE[0], + self.kernel_param_HKdiagSE[1], + ) + + # 4. Vectorize the time loop + # t_tensor shape: [T] + t_tensor = torch.arange(start, end, dtype=torch.float32, device=device) + + if len(t_tensor) == 0: + raise ValueError("Diffusion times are required for heat kernel") + + # Exponent matrix: shape [T, E] (Time steps by Eigenvalues) + exp_term = torch.exp(-t_tensor.unsqueeze(1) * evals.unsqueeze(0)) + + # Squared eigenvectors (phi^2): shape [E, N] + eigvec_mul = (evects**2).T + + # Single matrix multiplication replaces the entire loop: [T, E] @ [E, N] -> [T, N] + hk_diag = exp_term @ eigvec_mul + + # Apply spatial correction factor: t^(space_dim/2) + if self.space_dim != 0: + correction = t_tensor ** (self.space_dim / 2.0) + hk_diag = hk_diag * correction.unsqueeze(1) + + # Transpose to return [N, T] + hk_diag = hk_diag.transpose(0, 1) + + # Corner case checking + if ( + (torch.all(hk_diag == 0)) + and (num_nodes > 2) + and list(remove_self_loops(edge_index)[0].cpu().shape) != [2, 0] + ): + raise ValueError("HKdiagSE is all zeros") + + if torch.any(torch.isnan(hk_diag)): + raise ValueError("HKdiagSE contains NaNs") + + return hk_diag + + def _compute_exact( + self, edge_index: torch.Tensor, num_nodes: int + ) -> torch.Tensor: + """Compute HKdiagSE using the original, un-optimized CPU NumPy implementation. + + Parameters + ---------- + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + + Returns + ------- + torch.Tensor + Heat Kernel Diagonal Structural Encodings of shape ``[num_nodes, pe_dim]``. 
+ """ + import numpy as np + from torch_geometric.utils import to_scipy_sparse_matrix + + device = edge_index.device + if edge_index.size(1) == 0 or num_nodes <= 1: + return torch.zeros(num_nodes, self.pe_dim, device=device) + edge_index_lap, edge_weight = get_laplacian( edge_index, normalization="sym", num_nodes=num_nodes ) @@ -112,8 +238,8 @@ def _compute_hkdiag_se( ).astype(np.float64) evals, evects = np.linalg.eigh(L.toarray()) - evals = torch.from_numpy(evals) - evects = torch.from_numpy(evects) + evals = torch.from_numpy(evals).to(device) + evects = torch.from_numpy(evects).to(device) start, end = ( self.kernel_param_HKdiagSE[0], @@ -127,36 +253,27 @@ def _compute_hkdiag_se( hk_diag = [] evects = F.normalize(evects, p=2.0, dim=0) - # Remove eigenvalues == 0 from the computation of the heat kernel idx_remove = evals < 1e-8 evals = evals[~idx_remove] evects = evects[:, ~idx_remove] - # Change the shapes for the computations - evals = evals.unsqueeze(-1) # lambda_{i, ..., ...} - evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node + evals = evals.unsqueeze(-1) + evects = evects.transpose(0, 1) - # Compute the heat kernels diagonal only for each time eigvec_mul = evects**2 for t in kernel_times: - # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j} * phi_{i, j}) this_kernel = torch.sum( torch.exp(-t * evals) * eigvec_mul, dim=0, keepdim=False ) - - # Multiply by `t` to stabilize the values, since the gaussian height - # is proportional to `1/t` hk_diag.append(this_kernel * (t ** (self.space_dim / 2))) + hk_diag = torch.stack(hk_diag, dim=0).transpose(0, 1) - # TODO: some corner case when N=2 on MUTAG - if (torch.all(hk_diag == 0)) and (num_nodes > 2): - # Case when there is no connectivity - if list(remove_self_loops(edge_index)[0].cpu().shape) == [2, 0]: - pass - else: - raise ValueError("HKdiagSE is all zeros") + if ( + (torch.all(hk_diag == 0)) + and (num_nodes > 2) + and list(remove_self_loops(edge_index)[0].cpu().shape) != [2, 0] + ): + raise ValueError("HKdiagSE is all zeros") - if torch.any(torch.isnan(hk_diag)): - raise ValueError("HKdiagSE contains NaNs") - return hk_diag.float() + return hk_diag.float().to(device) diff --git a/topobench/transforms/data_manipulations/khop_feature_encodings.py b/topobench/transforms/data_manipulations/khop_feature_encodings.py index 520ff846d..ebc0c522d 100644 --- a/topobench/transforms/data_manipulations/khop_feature_encodings.py +++ b/topobench/transforms/data_manipulations/khop_feature_encodings.py @@ -1,9 +1,11 @@ """K-hop feature Encoding (KFE) for Hasse graphs Transform.""" +import time + import torch from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import to_dense_adj +from torch_geometric.utils import degree, to_dense_adj class KHopFE(BaseTransform): @@ -12,14 +14,22 @@ class KHopFE(BaseTransform): Parameters ---------- - max_hop: int + max_hop : int The maximum hop neighbourhood. concat_to_x : bool, optional If True, concatenates the encodings with existing node features in - ``data.x``. If ``data.x`` is None, creates it. Default is True. + ``data.x``. If ``data.x`` is None, creates it. + Default is True. aggregation : str, optional Aggregation function to reduce over the feature dimension. - Options: "mean", "sum", "max", "min". Default is "mean". + Options: "mean", "sum", "max", "min". + Default is "mean". + method : str, optional + Computation method: "dense" or "sparse". + Default is "sparse". 
+ debug : bool, optional + If True, runs both methods and prints error/timing metrics. + Default is False. **kwargs : dict Additional arguments (not used). """ @@ -31,12 +41,17 @@ def __init__( max_hop: int, concat_to_x: bool = True, aggregation: str = "mean", + method: str = "sparse", + debug: bool = False, **kwargs, ): self.concat_to_x = concat_to_x self.max_hop = ( max_hop - 1 ) # The 0-th hop is always the features themselves + self.method = method + self.debug = debug + if aggregation not in self._AGG_FN_MAP: raise ValueError( f"Unknown aggregation '{aggregation}'. " @@ -44,17 +59,20 @@ def __init__( ) self.aggregation = aggregation + if method not in ["dense", "sparse"]: + raise ValueError("Method must be 'dense' or 'sparse'.") + def forward(self, data: Data) -> Data: """Compute the K-hop feature encodings for the input graph. Parameters ---------- - data : Data + data : torch_geometric.data.Data Input graph data object. Returns ------- - Data + torch_geometric.data.Data Graph data object with K-hop feature encodings added. """ if data.x is None: @@ -79,9 +97,6 @@ def _compute_khopfe( ) -> torch.Tensor: """Internal method to compute K-hop feature encodings. - Propagates features through K hops and aggregates over input features - to produce a fixed-dimension output (matching PSE pattern). - Parameters ---------- x : torch.Tensor @@ -94,25 +109,173 @@ def _compute_khopfe( Returns ------- torch.Tensor - K-hop feature encodings of shape [N, max_hop]. + K-hop feature encodings of shape [num_nodes, max_hop, feature_dim]. """ device = edge_index.device x = x.to(device) - khop_fe = [] + if edge_index.size(1) == 0 or num_nodes <= 1: return torch.zeros(num_nodes, self.max_hop, device=device) + if self.debug: + print("\n--- KHopFE Debug Report ---") + is_cuda = device.type == "cuda" + print(f"Data device: {device}") + + # Helper function to track both time and peak GPU memory + def _track_execution(func): + if is_cuda: + torch.cuda.synchronize(device) + torch.cuda.reset_peak_memory_stats(device) + mem_start = torch.cuda.memory_allocated(device) + + t0 = time.time() + result = func(x, edge_index, num_nodes, device) + + if is_cuda: + torch.cuda.synchronize(device) + t_elapsed = time.time() - t0 + # Calculate peak memory used during the function call + mem_peak = torch.cuda.max_memory_allocated(device) + mem_used = mem_peak - mem_start + else: + t_elapsed = time.time() - t0 + mem_used = 0 + + return result, t_elapsed, mem_used + + # Exact (Dense) + fe_dense, t_dense, mem_dense = _track_execution( + self._compute_dense + ) + print(f"Dense compute time: {t_dense:.4f}s") + if is_cuda: + print(f"Dense peak memory: {mem_dense / (1024**2):.2f} MB") + + # Approx (Sparse) + fe_sparse, t_sparse, mem_sparse = _track_execution( + self._compute_sparse + ) + print(f"Sparse compute time: {t_sparse:.4f}s") + if is_cuda: + print(f"Sparse peak memory: {mem_sparse / (1024**2):.2f} MB") + + # Compare + diff = torch.abs(fe_dense - fe_sparse) + speedup = (t_dense / t_sparse) if t_sparse > 0 else float("inf") + print(f"\nSpeedup Factor (Time): {speedup:.2f}x") + + if is_cuda and mem_sparse > 0: + mem_ratio = mem_dense / mem_sparse + print( + f"Memory Factor (VRAM): {mem_ratio:.2f}x (Dense uses {mem_ratio:.1f}x more memory)" + ) + + print(f"Mean Abs Error: {diff.mean().item():.6e}") + print(f"Max Abs Error: {diff.max().item():.6e}") + print("---------------------------\n") + + fe_raw = fe_dense if self.method == "dense" else fe_sparse + else: + if self.method == "dense": + fe_raw = self._compute_dense(x, edge_index, 
num_nodes, device) + else: + fe_raw = self._compute_sparse(x, edge_index, num_nodes, device) + + # Aggregate over features + agg_fn = getattr(fe_raw, self._AGG_FN_MAP[self.aggregation]) + khop_fe = agg_fn(dim=-1) + + if torch.any(torch.isnan(khop_fe)): + raise ValueError("KHopFE contains NaNs") + + return khop_fe.float() + + def _compute_dense( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + num_nodes: int, + device: torch.device, + ) -> torch.Tensor: + """Compute KHopFE using original dense adjacency matrices. + + Parameters + ---------- + x : torch.Tensor + Node features of the graph. + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + device : torch.device + The device to perform computations on. + + Returns + ------- + torch.Tensor + Dense computation of K-hop feature encodings. + """ + khop_fe = [] A = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0) + # Symmetric norm adjacency matrix deg = A.sum(dim=1) deg_inv_sqrt = torch.diagflat(torch.pow(deg + 1e-8, -0.5)) A_norm = deg_inv_sqrt @ A @ deg_inv_sqrt - for _hop in range(self.max_hop): - x = A_norm @ x - khop_fe.append(x) - khop_fe = torch.stack(khop_fe, dim=1) # [N, max_hop, F] - # Aggregate over features to produce fixed-dimension output (like PSEs) - agg_fn = getattr(khop_fe, self._AGG_FN_MAP[self.aggregation]) - khop_fe = agg_fn(dim=-1) # [N, max_hop] - return khop_fe.float() + curr_x = x + for _ in range(self.max_hop): + curr_x = A_norm @ curr_x + khop_fe.append(curr_x) + + return torch.stack(khop_fe, dim=1) + + def _compute_sparse( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + num_nodes: int, + device: torch.device, + ) -> torch.Tensor: + """Compute KHopFE using optimized pure PyTorch sparse tensors. + + Parameters + ---------- + x : torch.Tensor + Node features of the graph. + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + device : torch.device + The device to perform computations on. + + Returns + ------- + torch.Tensor + Sparse computation of K-hop feature encodings. + """ + khop_fe = [] + row, col = edge_index + + # 1. Compute node degrees using the row indices (out-degree) + deg = degree(row, num_nodes, dtype=torch.float32) + deg_inv_sqrt = torch.pow(deg + 1e-8, -0.5) + + # 2. Compute symmetric normalized edge weights: (D^-0.5)[i] * (D^-0.5)[j] + # Since A[i,j] is 1 for existing edges, the weight is just the product of the inverse sqrts + edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col] + + # 3. Create the sparse symmetric normalized adjacency matrix + A_sparse = torch.sparse_coo_tensor( + edge_index, edge_weight, (num_nodes, num_nodes), device=device + ).coalesce() + + # 4. 
Iteratively propagate features via Sparse Matrix-Matrix multiplication + curr_x = x + for _ in range(self.max_hop): + curr_x = torch.sparse.mm(A_sparse, curr_x) + khop_fe.append(curr_x) + + return torch.stack(khop_fe, dim=1) diff --git a/topobench/transforms/data_manipulations/laplacian_encodings.py b/topobench/transforms/data_manipulations/laplacian_encodings.py index ba6f7cbf6..9bfff5f97 100644 --- a/topobench/transforms/data_manipulations/laplacian_encodings.py +++ b/topobench/transforms/data_manipulations/laplacian_encodings.py @@ -1,38 +1,44 @@ """Laplacian Positional Encoding (LapPE) Transform.""" +import time + import numpy as np import torch from scipy.sparse.linalg import eigsh from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import get_laplacian, to_scipy_sparse_matrix +from torch_geometric.utils import degree, get_laplacian, to_scipy_sparse_matrix class LapPE(BaseTransform): r""" Laplacian Positional Encoding (LapPE) transform. - This computes the smallest eigenvectors of the normalized Laplacian - matrix and appends them as node features (structural encodings). - Optionally pads to a fixed dimension. - Parameters ---------- max_pe_dim : int Maximum number of eigenvectors to use (dimensionality of the encoding). include_eigenvalues : bool, optional If True, concatenates eigenvalues alongside eigenvectors. - Shape then becomes ``[num_nodes, 2 * max_pe_dim]``. Default is False. + Default is False. include_first : bool, optional - If False, removes eigenvectors corresponding to (near-)zero eigenvalues - (constant eigenvector in connected graphs). Default is False. + If False, removes eigenvectors corresponding to (near-)zero eigenvalues. + Default is False. concat_to_x : bool, optional - If True, concatenates the encodings with existing node features in - ``data.x``. If ``data.x`` is None, creates it. Default is True. + If True, concatenates the encodings with existing node features. + Default is True. eps : float, optional - Small value to avoid division by zero. Default is 1e-6. + Small value to avoid division by zero. + Default is 1e-6. tolerance : float, optional - Tolerance for the eigenvalue solver. Default is 0.001. + Tolerance for the eigenvalue solver. + Default is 0.001. + method : str, optional + Computation method: "exact" (SciPy CPU) or "gpu" (PyTorch GPU). + Default is "gpu". + debug : bool, optional + If True, runs both methods and prints error/timing metrics. + Default is False. **kwargs : dict Additional arguments (not used). """ @@ -45,6 +51,8 @@ def __init__( concat_to_x: bool = True, eps: float = 1e-6, tolerance: float = 0.001, + method: str = "gpu", + debug: bool = False, **kwargs, ): self.max_pe_dim = max_pe_dim @@ -53,21 +61,53 @@ def __init__( self.concat_to_x = concat_to_x self.eps = eps self.tolerance = tolerance + self.debug = debug + + if method not in ["exact", "gpu"]: + raise ValueError("Method must be 'exact' or 'gpu'.") + self.method = method def forward(self, data: Data) -> Data: """Compute the Laplacian positional encodings for the input graph. Parameters ---------- - data : Data + data : torch_geometric.data.Data Input graph data object. Returns ------- - Data + torch_geometric.data.Data Graph data object with Laplacian positional encodings added. 
""" - pe = self._compute_lap_pe(data.edge_index, data.num_nodes) + if self.debug: + print("\n--- LapPE Debug Report ---") + print(f"Data device: {data.edge_index.device}") + # Exact Method (SciPy CPU) + t0 = time.time() + pe_exact = self._compute_exact(data.edge_index, data.num_nodes) + t_exact = time.time() - t0 + print(f"Exact compute time: {t_exact:.4f}s") + + # Fast Method (PyTorch GPU) + t0 = time.time() + pe_gpu = self._compute_gpu(data.edge_index, data.num_nodes) + t_gpu = time.time() - t0 + print(f"Fast compute time: {t_gpu:.4f}s") + + # Compare Tensors + diff = torch.abs(pe_exact - pe_gpu) + speedup = (t_exact / t_gpu) if t_gpu > 0 else float("inf") + print(f"Speedup Factor: {speedup:.2f}x") + print(f"Mean Abs Error: {diff.mean().item():.6e}") + print("--------------------------\n") + + pe = pe_exact if self.method == "exact" else pe_gpu + else: + if self.method == "exact": + pe = self._compute_exact(data.edge_index, data.num_nodes) + else: + pe = self._compute_gpu(data.edge_index, data.num_nodes) if self.concat_to_x: if data.x is None: @@ -79,10 +119,71 @@ def forward(self, data: Data) -> Data: return data - def _compute_lap_pe( + def _fix_sign_ambiguity(self, evecs: torch.Tensor) -> torch.Tensor: + """Standardize eigenvector signs so the max absolute value is positive. + + Parameters + ---------- + evecs : torch.Tensor + Eigenvectors tensor of shape ``[num_nodes, max_pe_dim]``. + + Returns + ------- + torch.Tensor + Sign-corrected eigenvectors tensor. + """ + max_idxs = torch.argmax(torch.abs(evecs), dim=0) + signs = torch.sign(evecs[max_idxs, torch.arange(evecs.shape[1])]) + # Replace 0 signs with 1 to avoid zeroing out vectors + signs[signs == 0] = 1 + return evecs * signs + + def _pad_and_concat( + self, + evals: torch.Tensor, + evecs: torch.Tensor, + num_nodes: int, + device: torch.device, + ) -> torch.Tensor: + """Pad to max_pe_dim and optionally concatenate eigenvalues. + + Parameters + ---------- + evals : torch.Tensor + Eigenvalues tensor of shape ``[max_pe_dim]``. + evecs : torch.Tensor + Eigenvectors tensor of shape ``[num_nodes, max_pe_dim]``. + num_nodes : int + Number of nodes in the graph. + device : torch.device + The device to place the resulting tensor on. + + Returns + ------- + torch.Tensor + The padded and optionally concatenated positional encoding tensor. + """ + # Pad if fewer than max_pe_dim + pad_width = self.max_pe_dim - evecs.shape[1] + if pad_width > 0: + evecs = torch.nn.functional.pad( + evecs, (0, pad_width), mode="constant", value=0 + ) + evals = torch.nn.functional.pad( + evals, (0, pad_width), mode="constant", value=0 + ) + + pe = evecs + if self.include_eigenvalues: + eigvals_broadcast = evals.unsqueeze(0).repeat(num_nodes, 1) + pe = torch.cat([pe, eigvals_broadcast], dim=-1) + + return pe + + def _compute_exact( self, edge_index: torch.Tensor, num_nodes: int ) -> torch.Tensor: - """Internal method to compute Laplacian eigenvector encodings. + """Compute LapPE using original SciPy CPU Implementation. Parameters ---------- @@ -94,14 +195,16 @@ def _compute_lap_pe( Returns ------- torch.Tensor - Laplacian positional encodings. + Exact Laplacian positional encodings. 
""" device = edge_index.device - if edge_index.size(1) == 0 or num_nodes <= 1: - return torch.zeros(num_nodes, self.max_pe_dim, device=device) + return torch.zeros( + num_nodes, + self.max_pe_dim * (2 if self.include_eigenvalues else 1), + device=device, + ) - # Normalized Laplacian edge_index_lap, edge_weight = get_laplacian( edge_index, normalization="sym", num_nodes=num_nodes ) @@ -109,42 +212,122 @@ def _compute_lap_pe( edge_index_lap, edge_weight, num_nodes ).astype(np.float64) - k = min(self.max_pe_dim, max(1, num_nodes - 1)) + k = min( + self.max_pe_dim + (0 if self.include_first else 1), num_nodes - 1 + ) + k = max(1, k) try: evals, evecs = eigsh(L, k=k, which="SM", tol=self.tolerance) except Exception: evals, evecs = np.linalg.eigh(L.toarray()) - # Drop trivial eigenvectors if requested if not self.include_first: mask = evals > self.eps evals, evecs = evals[mask], evecs[:, mask] - # Take up to k evals, evecs = evals[: self.max_pe_dim], evecs[:, : self.max_pe_dim] - # Fix sign ambiguity - for i in range(evecs.shape[1]): - max_idx = np.argmax(np.abs(evecs[:, i])) - if evecs[max_idx, i] < 0: - evecs[:, i] *= -1 + # Convert to PyTorch for sign fixing and padding + evals = torch.from_numpy(evals).float().to(device) + evecs = torch.from_numpy(evecs).float().to(device) - # Pad if fewer than max_pe_dim - if evecs.shape[1] < self.max_pe_dim: - pad_width = self.max_pe_dim - evecs.shape[1] - evecs = np.pad(evecs, ((0, 0), (0, pad_width)), mode="constant") - evals = np.pad(evals, (0, pad_width), mode="constant") + evecs = self._fix_sign_ambiguity(evecs) + return self._pad_and_concat(evals, evecs, num_nodes, device) - pe = torch.from_numpy(evecs).to(dtype=torch.float32, device=device) + def _compute_gpu( + self, edge_index: torch.Tensor, num_nodes: int + ) -> torch.Tensor: + """Compute LapPE using gpu PyTorch GPU Implementation with Shift Trick. - if self.include_eigenvalues: - eigvals_broadcast = torch.from_numpy(evals).to( - dtype=torch.float32, device=device - ) - eigvals_broadcast = eigvals_broadcast.unsqueeze(0).repeat( - num_nodes, 1 + Parameters + ---------- + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + + Returns + ------- + torch.Tensor + Fast approximation of Laplacian positional encodings. + """ + device = edge_index.device + if edge_index.size(1) == 0 or num_nodes <= 1: + return torch.zeros( + num_nodes, + self.max_pe_dim * (2 if self.include_eigenvalues else 1), + device=device, ) - pe = torch.cat([pe, eigvals_broadcast], dim=-1) - return pe + # We need k + 1 if we are dropping the first eigenvalue + k_compute = min( + self.max_pe_dim + (0 if self.include_first else 1), num_nodes + ) + + # 1. Get exact Laplacian edge weights + edge_index_lap, edge_weight_lap = get_laplacian( + edge_index, normalization="sym", num_nodes=num_nodes + ) + + # 2. Decide solver based on graph size + if num_nodes < 128 or k_compute >= num_nodes: + # For small graphs, dense PyTorch GPU is instantaneously gpu and mathematically perfectly stable + L_dense = torch.sparse_coo_tensor( + edge_index_lap, + edge_weight_lap.float(), + (num_nodes, num_nodes), + device=device, + ).to_dense() + evals, evecs = torch.linalg.eigh(L_dense) + else: + try: + # 3. 
The Shift Trick: Find LARGEST eigenvalues of Adjacency (A = I - L) + # We extract the Adjacency matrix by negating the off-diagonal Laplacian elements + row, col = edge_index + deg = degree(col, num_nodes, dtype=torch.float32) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0) + a_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col] + + A_sym = torch.sparse_coo_tensor( + edge_index, a_weight, (num_nodes, num_nodes), device=device + ).coalesce() + + # Provide initial guess X to speed up convergence + X = torch.randn( + num_nodes, k_compute, dtype=torch.float32, device=device + ) + + # lobpcg computes largest eigenvalues natively + evals_A, evecs = torch.lobpcg( + A=A_sym, X=X, largest=True, tol=self.tolerance + ) + + # Convert back to Laplacian eigenvalues (L = I - A) + evals = 1.0 - evals_A + + # lobpcg returns descending order; we need ascending order (smallest first) + evals, indices = torch.sort(evals, descending=False) + evecs = evecs[:, indices] + + except Exception: + # If the sparse graph is highly ill-conditioned and lobpcg fails, fallback to dense GPU + L_dense = torch.sparse_coo_tensor( + edge_index_lap, + edge_weight_lap.float(), + (num_nodes, num_nodes), + device=device, + ).to_dense() + evals, evecs = torch.linalg.eigh(L_dense) + + # 4. Mask, Slice, and Format + if not self.include_first: + mask = evals > self.eps + evals, evecs = evals[mask], evecs[:, mask] + + evals = evals[: self.max_pe_dim] + evecs = evecs[:, : self.max_pe_dim] + + evecs = self._fix_sign_ambiguity(evecs) + return self._pad_and_concat(evals, evecs, num_nodes, device) diff --git a/topobench/transforms/data_manipulations/ppr_feature_encodings.py b/topobench/transforms/data_manipulations/ppr_feature_encodings.py index 5475c7729..b0987642a 100644 --- a/topobench/transforms/data_manipulations/ppr_feature_encodings.py +++ b/topobench/transforms/data_manipulations/ppr_feature_encodings.py @@ -1,36 +1,44 @@ """Personalized Page Rank Feature Encoding (PPRFE) Transform.""" +import time + import numpy as np import omegaconf import torch from scipy.linalg import inv from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import to_dense_adj +from torch_geometric.utils import add_self_loops, degree, to_dense_adj class PPRFE(BaseTransform): r""" Personalized Page Rank Feature Encodings (PPRFE) transform. - Computes PPR diffusion of node features using the formula: - PPR = α(I - (1-α)Ã)^{-1} - where à = D^{-1/2} A^ D^{-1/2} is the normalized adjacency matrix - and A^ = A + I (adjacency with self-loops). - Parameters ---------- alpha_param_PPRFE : tuple of float Tuple specifying the start and end teleport probabilities (alpha values). - Values should be in (0, 1]. Higher alpha = more local, lower = more global. concat_to_x : bool, optional - If True, concatenates the encodings with existing node features in - ``data.x``. If ``data.x`` is None, creates it. Default is True. + If True, concatenates the encodings with existing node features. + Default is True. aggregation : str, optional Aggregation function to reduce over the feature dimension. - Options: "mean", "sum", "max", "min". Default is "mean". + Options: "mean", "sum", "max", "min". + Default is "mean". self_loop : bool, optional - If True, adds self-loops to the adjacency matrix. Default is True. + If True, adds self-loops to the adjacency matrix. + Default is True. + method : str, optional + Computation method: "exact" or "approx". + Default is "approx". 
+ appnp_K : int, optional + Number of polynomial expansion terms (propagation steps) for the approx method. + Higher means more global information but slower. + Default is 20. + debug : bool, optional + If True, runs both methods and prints error/timing metrics. + Default is False. **kwargs : dict Additional arguments (not used). """ @@ -43,23 +51,29 @@ def __init__( concat_to_x: bool = True, aggregation: str = "mean", self_loop: bool = True, + method: str = "approx", + appnp_K: int = 20, + debug: bool = False, **kwargs, ): self.alpha_param_PPRFE = alpha_param_PPRFE self.concat_to_x = concat_to_x self.self_loop = self_loop + if aggregation not in self._AGG_FN_MAP: - raise ValueError( - f"Unknown aggregation '{aggregation}'. " - f"Choose from: {list(self._AGG_FN_MAP.keys())}" - ) + raise ValueError(f"Unknown aggregation '{aggregation}'.") self.aggregation = aggregation - # Compute fe_dim from tuple/list + + if method not in ["exact", "approx"]: + raise ValueError("Method must be 'exact' or 'approx'.") + self.method = method + self.appnp_K = appnp_K + self.debug = debug + if ( isinstance(alpha_param_PPRFE, (list, tuple)) or type(alpha_param_PPRFE) is omegaconf.listconfig.ListConfig ): - # Number of alpha values to use self.fe_dim = alpha_param_PPRFE[1] else: self.fe_dim = alpha_param_PPRFE @@ -69,12 +83,12 @@ def forward(self, data: Data) -> Data: Parameters ---------- - data : Data + data : torch_geometric.data.Data Input graph data object. Returns ------- - Data + torch_geometric.data.Data Graph data object with PPR feature encodings added. """ if data.x is None: @@ -111,57 +125,198 @@ def _compute_pprfe( Returns ------- torch.Tensor - PPR feature encodings of shape [N, fe_dim]. + PPR feature encodings of shape ``[num_nodes, fe_dim]``. """ device = edge_index.device if edge_index.size(1) == 0 or num_nodes <= 1: return torch.zeros(num_nodes, self.fe_dim, device=device) - # Convert to dense adjacency matrix + start, num_alphas = ( + self.alpha_param_PPRFE[0], + self.alpha_param_PPRFE[1], + ) + alpha_values = np.linspace(start, 0.9, num_alphas) + + if self.debug: + print("\n--- PPRFE Debug Report ---") + print(f"Data device: {device}") + # Exact + t0 = time.time() + fe_exact = self._compute_exact( + x, edge_index, num_nodes, alpha_values, device + ) + t_exact = time.time() - t0 + print(f"Exact compute time: {t_exact:.4f}s") + + # Approx + t0 = time.time() + fe_approx = self._compute_approx( + x, edge_index, num_nodes, alpha_values, device + ) + t_approx = time.time() - t0 + print( + f"Approx compute time: {t_approx:.4f}s (Polynomial Order K: {self.appnp_K})" + ) + + # Compare + diff = torch.abs(fe_exact - fe_approx) + reldiff = diff / (torch.abs(fe_exact) + 1e-8) + mean_reldiff = reldiff.mean().item() + + print(f"Speedup Factor: {t_exact / t_approx:.2f}x") + print(f"Mean Abs Error: {diff.mean().item():.6e}") + print(f"Max Abs Error: {diff.max().item():.6e}") + print(f"Mean Rel Error: {mean_reldiff:.6e}") + print("--------------------------\n") + + fe_raw = fe_exact if self.method == "exact" else fe_approx + else: + if self.method == "exact": + fe_raw = self._compute_exact( + x, edge_index, num_nodes, alpha_values, device + ) + else: + fe_raw = self._compute_approx( + x, edge_index, num_nodes, alpha_values, device + ) + + # Aggregate over features + agg_fn = getattr(fe_raw, self._AGG_FN_MAP[self.aggregation]) + fe_agg = agg_fn(dim=-1) + + if torch.any(torch.isnan(fe_agg)): + raise ValueError("PPRFE contains NaNs") + + return fe_agg.float() + + def _compute_exact( + self, + x: torch.Tensor, + 
edge_index: torch.Tensor, + num_nodes: int, + alpha_values: np.ndarray, + device: torch.device, + ) -> torch.Tensor: + """Compute exact O(N^3) dense matrix inversion method for PPR. + + Parameters + ---------- + x : torch.Tensor + Node features of the graph. + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + alpha_values : numpy.ndarray + Array of teleport probabilities. + device : torch.device + The device to perform computations on. + + Returns + ------- + torch.Tensor + Exact PPR feature encodings. + """ adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0] adj_np = adj.cpu().numpy().astype(np.float64) - # Add self-loops: A^ = A + I if self.self_loop: adj_np = adj_np + np.eye(num_nodes) - # Compute degree matrix D deg = np.sum(adj_np, axis=1) - # Handle isolated nodes deg_safe = np.where(deg > 0, deg, 1.0) deg_inv_sqrt = np.diag(1.0 / np.sqrt(deg_safe)) - - # Normalized adjacency: à = D^{-1/2} A^ D^{-1/2} adj_norm = deg_inv_sqrt @ adj_np @ deg_inv_sqrt - # Generate alpha values (teleport probabilities) - start, num_alphas = ( - self.alpha_param_PPRFE[0], - self.alpha_param_PPRFE[1], - ) - # Use linear spacing for alpha in (0, 1] - # Start from a small alpha (more global) to larger alpha (more local) - alpha_values = np.linspace(start, 0.9, num_alphas) - x_np = x.detach().cpu().numpy().astype(np.float64) ppr_fe = [] - identity = np.eye(num_nodes) + for alpha in alpha_values: - # PPR = α(I - (1-α)Ã)^{-1} ppr_matrix = alpha * inv(identity - (1 - alpha) * adj_norm) - # Diffuse features x_ppr = ppr_matrix @ x_np ppr_fe.append(torch.from_numpy(x_ppr).float().to(device)) - ppr_fe = torch.stack(ppr_fe, dim=1) # [N, fe_dim, F] + return torch.stack(ppr_fe, dim=1) - # Aggregate over features to produce fixed-dimension output - agg_fn = getattr(ppr_fe, self._AGG_FN_MAP[self.aggregation]) - ppr_fe = agg_fn(dim=-1) # [N, fe_dim] + def _compute_approx( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + num_nodes: int, + alpha_values: np.ndarray, + device: torch.device, + ) -> torch.Tensor: + """Compute fast APPNP polynomial approximation using sparse matrix multiplication. - if torch.any(torch.isnan(ppr_fe)): - raise ValueError("PPRFE contains NaNs") + Parameters + ---------- + x : torch.Tensor + Node features of the graph. + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + alpha_values : numpy.ndarray + Array of teleport probabilities. + device : torch.device + The device to perform computations on. + + Returns + ------- + torch.Tensor + Approximated PPR feature encodings. + """ + # 1. Add self loops if needed + edge_weight = torch.ones( + edge_index.size(1), dtype=torch.float32, device=device + ) + if self.self_loop: + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value=1.0, num_nodes=num_nodes + ) + + # 2. Compute symmetric degree normalization: D^{-1/2} A D^{-1/2} + row, col = edge_index + deg = degree(col, num_nodes, dtype=torch.float32) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0) + edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + # 3. Create sparse adjacency tensor + A_tilde = torch.sparse_coo_tensor( + edge_index, edge_weight, (num_nodes, num_nodes), device=device + ).coalesce() + + # 4. 
Precompute the graph diffusions T_k = A_tilde^k * X + T_x = [x] + for _ in range(self.appnp_K): + T_x.append(torch.sparse.mm(A_tilde, T_x[-1])) + + # Stack to shape [K+1, N, F] + T_x_tensor = torch.stack(T_x, dim=0) + + # 5. Vectorized alpha coefficients + # Shape alpha_tensor: [num_alphas] + alpha_tensor = torch.tensor( + alpha_values, dtype=torch.float32, device=device + ) + # Shape k_tensor: [K+1] + k_tensor = torch.arange( + self.appnp_K + 1, dtype=torch.float32, device=device + ) + + # Calculate coefficients: alpha * (1 - alpha)^k + # Expand dims to calculate outer product-like matrix. Result shape: [num_alphas, K+1] + alphas_expanded = alpha_tensor.unsqueeze(1) + ks_expanded = k_tensor.unsqueeze(0) + coeffs = alphas_expanded * (1.0 - alphas_expanded) ** ks_expanded + + # 6. Apply coefficients to diffusions via Einsum + # coeffs: [num_alphas, K+1] -> 'ak' + # T_x_tensor: [K+1, N, F] -> 'knf' + # Output: [N, num_alphas, F] -> 'naf' + ppr_fe = torch.einsum("ak, knf -> naf", coeffs, T_x_tensor) - return ppr_fe.float() + return ppr_fe diff --git a/topobench/transforms/data_manipulations/random_walk_encodings.py b/topobench/transforms/data_manipulations/random_walk_encodings.py index 25388b4da..622209eea 100644 --- a/topobench/transforms/data_manipulations/random_walk_encodings.py +++ b/topobench/transforms/data_manipulations/random_walk_encodings.py @@ -1,5 +1,7 @@ """Random Walk Structural Encodings (RWSE) Transform.""" +import time + import torch from torch_geometric.data import Data from torch_geometric.transforms import BaseTransform @@ -9,40 +11,132 @@ class RWSE(BaseTransform): r"""Random Walk Structural Encoding (RWSE) transform. - Computes return probabilities of random walks of length 1..K - for each node in the graph, and appends them as structural - encodings to node features. - Parameters ---------- max_pe_dim : int Maximum walk length (number of RWSE dimensions). concat_to_x : bool, optional - If True, concatenates the encodings with existing node - features in ``data.x``. If ``data.x`` is None, creates it. + If True, concatenates the encodings with existing node features. Default is True. + method : str, optional + Computation method: "dense", "sparse", or "batched". + "dense" uses standard matrix multiplication (Memory intensive). + "sparse" uses pure sparse matrix multiplication (Fastest, moderate memory). + "batched" uses indicator diffusion (Memory-bounded, slightly slower). + Default is "batched". + batch_size : int, optional + Number of nodes to process simultaneously when using the "batched" method. + Lower values use less memory but take slightly longer. Default is 128. + debug : bool, optional + If True, runs all methods, catches OOM errors, and prints a detailed + timing and peak VRAM memory footprint report. Default is False. **kwargs : dict Additional arguments (not used). """ - def __init__(self, max_pe_dim: int, concat_to_x: bool = True, **kwargs): + def __init__( + self, + max_pe_dim: int, + concat_to_x: bool = True, + method: str = "batched", + batch_size: int = 128, + debug: bool = False, + **kwargs, + ): self.max_pe_dim = max_pe_dim self.concat_to_x = concat_to_x + self.batch_size = batch_size + self.debug = debug + + if method not in ["dense", "sparse", "batched"]: + raise ValueError("Method must be 'dense', 'sparse', or 'batched'.") + self.method = method def forward(self, data: Data) -> Data: """Compute the RWSE for the input graph. Parameters ---------- - data : Data + data : torch_geometric.data.Data Input graph data object.
Returns ------- - Data - Graph data object with RWSE added. + torch_geometric.data.Data + Graph data object with RWSE added to ``data.x`` or ``data.RWSE``. """ - pe = self._compute_rwse(data.edge_index, data.num_nodes) + if self.debug: + print("\n--- RWSE Debug Report ---") + print(f"Data device: {data.edge_index.device}") + # 1. Dense Method + try: + t0 = time.time() + pe_dense, t_dense, mem_dense = self._profile_method( + self._compute_dense, data.edge_index, data.num_nodes + ) + t_dense_total = time.time() - t0 + dense_status = f"{t_dense_total:.4f}s | {mem_dense:.2f} MB" + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pe_dense = None + dense_status = "OOM (Out Of Memory) 💥" + else: + raise e + + # 2. Sparse Method + try: + t0 = time.time() + pe_sparse, t_sparse, mem_sparse = self._profile_method( + self._compute_sparse, data.edge_index, data.num_nodes + ) + t_sparse_total = time.time() - t0 + sparse_status = f"{t_sparse_total:.4f}s | {mem_sparse:.2f} MB" + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pe_sparse = None + sparse_status = "OOM (Out Of Memory) 💥" + else: + raise e + + # 3. Batched Method + t0 = time.time() + pe_batched, t_batched, mem_batched = self._profile_method( + self._compute_batched, data.edge_index, data.num_nodes + ) + t_batched_total = time.time() - t0 + batched_status = f"{t_batched_total:.4f}s | {mem_batched:.2f} MB" + + # Print Report + print(f"{'Method':<10} | {'Status (Time | Peak VRAM)':<30}") + print("-" * 45) + print(f"{'Dense':<10} | {dense_status:<30}") + print(f"{'Sparse':<10} | {sparse_status:<30}") + print(f"{'Batched':<10} | {batched_status:<30}") + print("-" * 45) + + # Comparisons + if pe_dense is not None and pe_sparse is not None: + diff_ds = torch.abs(pe_dense - pe_sparse).max().item() + print(f"Max Abs Error (Dense vs Sparse): {diff_ds:.6e}") + if pe_sparse is not None and pe_batched is not None: + diff_sb = torch.abs(pe_sparse - pe_batched).max().item() + print(f"Max Abs Error (Sparse vs Batched): {diff_sb:.6e}") + print("-" * 45 + "\n") + + # Select output based on requested method + if self.method == "dense" and pe_dense is not None: + pe = pe_dense + elif self.method == "batched": + pe = pe_batched + else: + pe = pe_sparse + else: + if self.method == "dense": + pe = self._compute_dense(data.edge_index, data.num_nodes) + elif self.method == "batched": + pe = self._compute_batched(data.edge_index, data.num_nodes) + else: + pe = self._compute_sparse(data.edge_index, data.num_nodes) if self.concat_to_x: if data.x is None: @@ -54,44 +148,210 @@ def forward(self, data: Data) -> Data: return data - def _compute_rwse( + def _profile_method(self, func, edge_index: torch.Tensor, num_nodes: int): + """Helper method to profile execution time and memory (CPU or GPU). + + Parameters + ---------- + func : callable + The computation method to profile. + edge_index : torch.Tensor + Edge indices of the graph. + num_nodes : int + Number of nodes in the graph. + + Returns + ------- + tuple + A tuple containing (result_tensor, time_elapsed_seconds, peak_memory_mb). 
+ """ + import tracemalloc + + device = edge_index.device + is_cuda = device.type == "cuda" + + if is_cuda: + # --- GPU Memory Tracking --- + torch.cuda.synchronize(device) + # Record the baseline memory before the function starts + start_mem = torch.cuda.memory_allocated(device) + torch.cuda.reset_peak_memory_stats(device) + + t0 = time.time() + pe = func(edge_index, num_nodes) + torch.cuda.synchronize(device) + t_elapsed = time.time() - t0 + + # Get the peak memory reached during the function + peak_mem = torch.cuda.max_memory_allocated(device) + + # The actual footprint of this method is the Peak minus the Baseline + mem_mb = (peak_mem - start_mem) / (1024 * 1024) + else: + # --- CPU Memory Tracking --- + tracemalloc.start() + + t0 = time.time() + pe = func(edge_index, num_nodes) + t_elapsed = time.time() - t0 + + # tracemalloc returns (current_memory, peak_memory) in bytes + _, peak_bytes = tracemalloc.get_traced_memory() + tracemalloc.stop() + + mem_mb = peak_bytes / (1024 * 1024) + + return pe, t_elapsed, mem_mb + + def _compute_dense( self, edge_index: torch.Tensor, num_nodes: int ) -> torch.Tensor: - """Internal method to compute RWSE return probabilities. + """Compute RWSE using original dense matrix multiplication. Parameters ---------- edge_index : torch.Tensor - Edge indices of the graph. + Edge indices of the graph of shape ``[2, num_edges]``. num_nodes : int Number of nodes in the graph. Returns ------- torch.Tensor - RWSE return probabilities. + RWSE return probabilities of shape ``[num_nodes, max_pe_dim]``. """ device = edge_index.device - if edge_index.numel() == 0 or num_nodes <= 1: return torch.zeros(num_nodes, self.max_pe_dim, device=device) - # Degree and adjacency deg = degree(edge_index[0], num_nodes=num_nodes).float().to(device) deg = torch.where(deg == 0, torch.ones_like(deg), deg) adj = torch.zeros(num_nodes, num_nodes, device=device) adj[edge_index[0], edge_index[1]] = 1.0 - # Transition matrix P = adj / deg.unsqueeze(-1) - - # RWSE features rwse = torch.zeros(num_nodes, self.max_pe_dim, device=device) P_power = torch.eye(num_nodes, device=device) for k in range(1, self.max_pe_dim + 1): P_power = P_power @ P - rwse[:, k - 1] = P_power.diag() # return probs + rwse[:, k - 1] = P_power.diag() return rwse.float() + + def _compute_sparse( + self, edge_index: torch.Tensor, num_nodes: int + ) -> torch.Tensor: + """Compute RWSE using optimized PyTorch sparse matrix multiplication. + + Parameters + ---------- + edge_index : torch.Tensor + Edge indices of the graph of shape ``[2, num_edges]``. + num_nodes : int + Number of nodes in the graph. + + Returns + ------- + torch.Tensor + RWSE return probabilities of shape ``[num_nodes, max_pe_dim]``. + """ + device = edge_index.device + if edge_index.numel() == 0 or num_nodes <= 1: + return torch.zeros(num_nodes, self.max_pe_dim, device=device) + + row, col = edge_index + + # 1. Compute Out-Degree + deg = degree(row, num_nodes=num_nodes, dtype=torch.float32) + deg_inv = 1.0 / deg.clamp_(min=1.0) + + # 2. Transition probabilities: P_{i,j} = 1 / deg(i) + edge_weight = deg_inv[row] + + # 3. Create Sparse Transition Matrix P + P = torch.sparse_coo_tensor( + edge_index, edge_weight, (num_nodes, num_nodes), device=device + ).coalesce() + + rwse = [] + Pk = P + + # Pre-allocate a zero tensor to avoid re-allocating memory inside the loop + pe_k = torch.zeros(num_nodes, device=device) + + for _ in range(self.max_pe_dim): + # 1. Grab coordinates and values + row, col = Pk.indices() + val = Pk.values() + + # 2. 
Find the diagonal elements (where row index == col index) + mask = row == col + + # 3. Drop them into the pre-allocated zero tensor and save + pe_k.zero_() # Reset the tensor inplace + pe_k.scatter_(0, row[mask], val[mask]) + rwse.append(pe_k.clone()) # Clone to save this step's state + + # 4. Advance the random walk + Pk = torch.sparse.mm(Pk, P) + + return torch.stack(rwse, dim=1) + + def _compute_batched( + self, edge_index: torch.Tensor, num_nodes: int + ) -> torch.Tensor: + """Compute RWSE using memory-bounded batched indicator diffusion. + + Parameters + ---------- + edge_index : torch.Tensor + Edge indices of the graph of shape ``[2, num_edges]``. + num_nodes : int + Number of nodes in the graph. + + Returns + ------- + torch.Tensor + RWSE return probabilities of shape ``[num_nodes, max_pe_dim]``. + """ + device = edge_index.device + if edge_index.numel() == 0 or num_nodes <= 1: + return torch.zeros(num_nodes, self.max_pe_dim, device=device) + + row, col = edge_index + + # 1. Compute Out-Degree and Edge Weights + deg = degree(row, num_nodes=num_nodes, dtype=torch.float32) + deg_inv = 1.0 / deg.clamp_(min=1.0) + edge_weight = deg_inv[row] + + # 2. Create Sparse Transition Matrix P + P = torch.sparse_coo_tensor( + edge_index, edge_weight, (num_nodes, num_nodes), device=device + ).coalesce() + + rwse = torch.zeros(num_nodes, self.max_pe_dim, device=device) + + # 3. Process nodes in strict memory-bounded batches + for start_idx in range(0, num_nodes, self.batch_size): + end_idx = min(start_idx + self.batch_size, num_nodes) + current_batch_size = end_idx - start_idx + + # Create an indicator matrix for this specific batch: Shape [N, B] + X = torch.zeros(num_nodes, current_batch_size, device=device) + batch_nodes = torch.arange(start_idx, end_idx, device=device) + batch_indices = torch.arange(current_batch_size, device=device) + X[batch_nodes, batch_indices] = 1.0 + + # Diffuse the features K times + for k in range(self.max_pe_dim): + # Matrix-Vector multiplication: [N, N] @ [N, B] -> [N, B] + X = torch.sparse.mm(P, X) + + # Extract the diagonal equivalent for this batch + return_probs = X[batch_nodes, batch_indices] + rwse[start_idx:end_idx, k] = return_probs + + return rwse diff --git a/topobench/utils/config_resolvers.py b/topobench/utils/config_resolvers.py index 5ce3cad75..3886a76c6 100644 --- a/topobench/utils/config_resolvers.py +++ b/topobench/utils/config_resolvers.py @@ -906,7 +906,7 @@ def get_list_element(list, index): return list[index] -def infer_in_khop_feature_dim(dataset_in_channels, max_hop): +def infer_in_khop_feature_dim(dataset_in_channels, max_hop, complex_dim=None): r"""Infer the dimension of the feature vector in the SANN k-hop model. Parameters @@ -915,12 +915,19 @@ def infer_in_khop_feature_dim(dataset_in_channels, max_hop): 1D array of input channels for the dataset. max_hop : int Maximum hop distance. + complex_dim : int, optional + Number of cell ranks processed by the transform. When provided, + ``dataset_in_channels`` is truncated to this length so the + recursive formula only considers ranks that actually appear in + the k-hop feature computation. Returns ------- int : Dimension of the feature vector in the SANN k-hop model. """ + if complex_dim is not None: + dataset_in_channels = list(dataset_in_channels)[:complex_dim] def compute_recursive_sequence(initial_values, time_steps): """Compute the sequence D_k^(t) based on the given recursive formula. 
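Both the sparse and batched RWSE paths changed above reduce to the same recipe: build a row-normalized sparse transition matrix P and read off the diagonal of P^k for k = 1..K. The snippet below is a minimal, self-contained sketch of that recipe in plain PyTorch, included only to make the diff easier to follow; the function name rwse_sketch and the toy graph are invented for this illustration and are not part of the RWSE transform, which additionally handles empty graphs, device placement, and concatenation with data.x.

import torch

def rwse_sketch(edge_index: torch.Tensor, num_nodes: int, max_pe_dim: int) -> torch.Tensor:
    # Return probabilities diag(P^k) for k = 1..max_pe_dim, using sparse matmul only.
    row, col = edge_index
    deg = torch.zeros(num_nodes).scatter_add_(0, row, torch.ones(row.numel()))
    weight = (1.0 / deg.clamp(min=1.0))[row]          # P[i, j] = 1 / out_degree(i)
    P = torch.sparse_coo_tensor(edge_index, weight, (num_nodes, num_nodes)).coalesce()

    features, Pk = [], P
    for _ in range(max_pe_dim):
        idx, val = Pk.indices(), Pk.values()
        diag = torch.zeros(num_nodes)
        on_diag = idx[0] == idx[1]                    # entries sitting on the diagonal of P^k
        diag.scatter_(0, idx[0][on_diag], val[on_diag])
        features.append(diag)
        Pk = torch.sparse.mm(Pk, P).coalesce()        # advance the walk by one step
    return torch.stack(features, dim=1)               # shape [num_nodes, max_pe_dim]

# Toy example: a triangle (0-1-2) with a pendant node 3 attached to node 2.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0, 2, 3],
                           [1, 0, 2, 1, 0, 2, 3, 2]])
print(rwse_sketch(edge_index, num_nodes=4, max_pe_dim=3))

The batched variant in the diff trades the K sparse-sparse products above for sparse-dense products against indicator columns, which bounds peak memory at the cost of a little extra time.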
@@ -1056,7 +1063,9 @@ def set_preserve_edge_attr(model_name, default=True): bool Default if the model can preserve edge attributes, False otherwise. """ - if model_name in ["sann", "hopse_m", "hopse_g"]: + if model_name in ["hopse_m", "hopse_g"]: + return True + elif model_name in ["sann"]: return False else: return default diff --git a/tutorials/count_params.ipynb b/tutorials/count_params.ipynb index da017eff3..40d59523f 100644 --- a/tutorials/count_params.ipynb +++ b/tutorials/count_params.ipynb @@ -7,8 +7,80 @@ "metadata": {}, "outputs": [], "source": [ + "import argparse\n", + "import re\n", + "import shlex\n", + "import itertools\n", + "import pandas as pd\n", + "import hydra\n", + "from hydra.core.global_hydra import GlobalHydra\n", + "import torch\n", + "import time \n", + "import datetime \n", + "\n", + "# Import run to trigger OmegaConf resolvers\n", + "import topobench.run as tb_run \n", + "from topobench.data.preprocessor import PreProcessor\n", + "from topobench.dataloader import TBDataloader\n", + "\n", + "def count_detailed_parameters(model: torch.nn.Module, only_trainable: bool = True):\n", + " \"\"\"Counts parameters and groups them by their top-level module name.\"\"\"\n", + " total_params = 0\n", + " breakdown = {\"Backbone\": 0, \"Feature_Encoder\": 0, \"Readout\": 0, \"Other\": 0}\n", + "\n", + " for name, p in model.named_parameters():\n", + " if only_trainable and not p.requires_grad:\n", + " continue\n", + " \n", + " num = p.numel()\n", + " total_params += num\n", + " \n", + " if \"backbone\" in name and \"wrapper\" not in name:\n", + " breakdown[\"Backbone\"] += num\n", + " elif \"feature_encoder\" in name:\n", + " breakdown[\"Feature_Encoder\"] += num\n", + " elif \"readout\" in name:\n", + " breakdown[\"Readout\"] += num\n", + " else:\n", + " breakdown[\"Other\"] += num\n", + " \n", + " return total_params, breakdown\n", + "\n", + "def get_override_mapping(content: str) -> dict[str, str]:\n", + " \"\"\"Parses the pipe-separated SWEEP_CONFIG block from the bash script.\"\"\"\n", + " \n", + " block_pattern = r\"SWEEP_CONFIG=\\((.*?)\\n\\s*\\)\"\n", + " block_match = re.search(block_pattern, content, re.DOTALL)\n", + " \n", + " if not block_match:\n", + " raise ValueError(\"Could not find the SWEEP_CONFIG block. Check if it ends with a ')' on a new line.\")\n", + " \n", + " sweep_body = block_match.group(1)\n", + " mapping_pattern = r\"\\\"[^|]*\\|([^|]+)\\|\\$\\{([^|*\\[]+)\"\n", + " pairs = re.findall(mapping_pattern, sweep_body)\n", + "\n", + " return {bash_arr.strip(): hydra_key.strip() for hydra_key, bash_arr in pairs}\n", + "\n", + "def parse_bash_array(content: str, array_name: str, mapping: dict) -> tuple[list[str], str]:\n", + " \"\"\"Extracts elements from a bash array and its corresponding Hydra override string.\"\"\"\n", + " pattern = rf\"{array_name}=\\((.*?)\\)\"\n", + " match = re.search(pattern, content, re.DOTALL)\n", + " \n", + " if not match:\n", + " raise ValueError(f\"Could not find array '{array_name}' in the script.\")\n", + " \n", + " items = shlex.split(match.group(1), comments=True)\n", + " override_key = mapping.get(array_name, array_name)\n", + " \n", + " return items, override_key\n", + "\n", + "def get_hydra_val(val: str) -> str:\n", + " \"\"\"Extracts the hydra value from the 'alias::hydra_value' syntax.\"\"\"\n", + " return val.split('::', 1)[1] if '::' in val else val\n", + "\n", "def main(sh_path: str, sweep_arrays: list[str] = None, out_path: str = None):\n", - " # 1. 
Set default iteration parameters if none are provided\n", + " start_time_overall = time.time()\n", + " \n", " if sweep_arrays is None:\n", " sweep_arrays = [\"models\", \"neighborhoods\", \"encodings\", \"num_layers\", \"hidden_channels\"]\n", " \n", @@ -17,10 +89,8 @@ " with open(sh_path, 'r') as f:\n", " content = f.read()\n", " \n", - " # 2. Extract the Hydra override keys mapping from SWEEP_CONFIG\n", " mapping = get_override_mapping(content)\n", " \n", - " # 3. Dynamically extract values and keys for the target sweep arrays\n", " sweep_data = {}\n", " for arr in sweep_arrays:\n", " vals_raw, override_key = parse_bash_array(content, arr, mapping)\n", @@ -29,7 +99,6 @@ " \"values\": [get_hydra_val(v) for v in vals_raw]\n", " }\n", " \n", - " # 4. Automatically fix any non-swept parameters to their first value\n", " fixed_params = {}\n", " fixed_overrides = []\n", " \n", @@ -43,11 +112,9 @@ " fixed_overrides.append(f\"{override_key}={first_val}\")\n", " print(f\"Fixed '{arr_name}': {first_val}\")\n", " except Exception as e:\n", - " # Catch variables defined in mapping but missing/commented out in the bash script\n", " print(f\"Warning: Could not set fixed value for '{arr_name}'\")\n", " print(\"-\" * 54 + \"\\n\")\n", " \n", - " # 5. Generate Combinations dynamically based on sweep_arrays\n", " keys = list(sweep_data.keys())\n", " value_lists = [sweep_data[k][\"values\"] for k in keys]\n", " combinations = list(itertools.product(*value_lists))\n", @@ -56,15 +123,15 @@ " print(f\"Total combinations to evaluate: {len(combinations)}\")\n", " print(\"-\" * 60)\n", "\n", - " # 6. Initialize Hydra API\n", " GlobalHydra.instance().clear()\n", " hydra.initialize(version_base=\"1.3\", config_path=\"../configs\") \n", "\n", " for idx, combo in enumerate(combinations):\n", - " # Zip the current combination back to their bash array names\n", + " # --- START ITERATION TIMER ---\n", + " iter_start_time = time.time()\n", + " \n", " combo_dict = dict(zip(keys, combo))\n", " \n", - " # Validation checks: Get datasets from current combo OR fallback to the fixed value\n", " model_val = combo_dict.get(\"models\", \"\")\n", " dataset_val = combo_dict.get(\"datasets\", fixed_params.get(\"datasets\", \"\"))\n", " \n", @@ -72,35 +139,26 @@ " print(f\"[{idx+1}/{len(combinations)}] Skipping invalid combo (Cell model + Simplicial dataset)\")\n", " continue\n", "\n", - " # 7. Dynamically Build Hydra Overrides\n", - " # A. Add the dynamic sweep values\n", - " # 7. Dynamically Build Hydra Overrides\n", " overrides = []\n", " \n", - " # A. Add the dynamic sweep values\n", " for k, v in combo_dict.items():\n", " override_key = sweep_data[k]['override_key']\n", " \n", - " # Catch the special @@@ syntax used in transforms\n", " if \"@@@\" in str(v):\n", - " # The first part is the value for the primary key\n", " parts = str(v).split(\"@@@\")\n", " primary_val = parts[0].strip()\n", " overrides.append(f\"{override_key}={primary_val}\")\n", " \n", - " # The remaining parts are separate key=value overrides\n", " for extra_override in parts[1:]:\n", " if extra_override.strip():\n", " overrides.append(extra_override.strip())\n", " else:\n", " overrides.append(f\"{override_key}={v}\")\n", " \n", - " # B. 
Add the fixed values\n", " for arr_name, override_key in mapping.items():\n", " if arr_name not in sweep_arrays:\n", " fixed_val = fixed_params.get(arr_name, \"\")\n", " if fixed_val:\n", - " # Apply the same @@@ split logic to fixed parameters just in case\n", " if \"@@@\" in str(fixed_val):\n", " parts = str(fixed_val).split(\"@@@\")\n", " overrides.append(f\"{override_key}={parts[0].strip()}\")\n", @@ -110,18 +168,17 @@ " else:\n", " overrides.append(f\"{override_key}={fixed_val}\")\n", " \n", - " # C. Add environment/testing constants\n", " overrides.extend([\n", " \"train=False\", \n", " \"test=False\",\n", " \"trainer.accelerator=cpu\",\n", - " \"trainer.devices=auto\"\n", + " \"trainer.devices=auto\",\n", + " \"dataset.dataloader_params.batch_size=1\"\n", " ])\n", "\n", " try:\n", " cfg = hydra.compose(config_name=\"run.yaml\", overrides=overrides)\n", " \n", - " # --- DATALOADER INSTANTIATION ---\n", " dataset_loader = hydra.utils.instantiate(cfg.dataset.loader)\n", " loaded_dataset, dataset_dir = dataset_loader.load()\n", " \n", @@ -146,8 +203,7 @@ " \n", " total_params, breakdown = count_detailed_parameters(instantiated_model)\n", " \n", - " # 8. Dynamically construct the result row\n", - " res = {k: combo_dict[k] for k in keys} # Add the sweep values first\n", + " res = {k: combo_dict[k] for k in keys} \n", " res.update({\n", " \"Total_Params\": total_params,\n", " \"Backbone_Params\": breakdown[\"Backbone\"],\n", @@ -157,24 +213,34 @@ " })\n", " results.append(res)\n", " \n", - " # Create a quick log string based on the active sweep elements\n", + " # --- END ITERATION TIMER (SUCCESS) ---\n", + " iter_time = time.time() - iter_start_time\n", + " \n", " log_info = \" \".join([f\"{k[:3]}:{str(v).split('/')[-1]}\" for k, v in combo_dict.items()])\n", - " print(f\"[{idx+1}/{len(combinations)}] SUCCESS | Total: {total_params:,} ({log_info})\")\n", + " print(f\"[{idx+1}/{len(combinations)}] SUCCESS in {iter_time:.2f}s | Total: {total_params:,} ({log_info})\")\n", "\n", " except Exception as e:\n", - " print(f\"[{idx+1}/{len(combinations)}] FAILED | Error: {e}\")\n", + " # --- END ITERATION TIMER (FAIL) ---\n", + " iter_time = time.time() - iter_start_time\n", + " print(f\"[{idx+1}/{len(combinations)}] FAILED in {iter_time:.2f}s | Error: {e}\")\n", "\n", - " # 9. 
Export Results\n", " print(\"-\" * 60)\n", " df = pd.DataFrame(results)\n", " \n", " if not df.empty and out_path is not None:\n", - " # Sort dynamically by the first sweep parameter and then Backbone_Params\n", " df = df.sort_values(by=[keys[0], \"Backbone_Params\"])\n", " df.to_csv(out_path, index=False)\n", - " print(f\"\\nSaved results to {out_path}\")\n", + " print(f\"Saved results to {out_path}\")\n", " else:\n", - " print(\"No successful combinations to report.\")" + " print(\"No successful combinations to report.\")\n", + "\n", + " end_time_overall = time.time()\n", + " total_time_seconds = end_time_overall - start_time_overall\n", + " formatted_time = str(datetime.timedelta(seconds=int(total_time_seconds)))\n", + " \n", + " print(\"-\" * 60)\n", + " print(f\"Overall Script Time: {formatted_time} (HH:MM:SS)\")\n", + " print(\"-\" * 60)" ] }, { @@ -193,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "20ac90c5", "metadata": {}, "outputs": [ @@ -204,7 +270,7 @@ "Parsing parameters from ../scripts/gcn.sh...\n", "\n", "--- Fixed Parameters (Using 1st value from script) ---\n", - "Fixed 'datasets': graph/MUTAG\n", + "Fixed 'datasets': graph/cocitation_cora\n", "Fixed 'proj_dropouts': 0.25\n", "Fixed 'lrs': 0.01\n", "Fixed 'weight_decays': 0\n", @@ -214,12 +280,12 @@ "\n", "Total combinations to evaluate: 18\n", "------------------------------------------------------------\n", - "[1/18] SUCCESS | Total: 18,071 (mod:gcn tra:no_transform num:1 hid:128)\n", - "[2/18] SUCCESS | Total: 68,887 (mod:gcn tra:no_transform num:1 hid:256)\n", - "[3/18] SUCCESS | Total: 34,583 (mod:gcn tra:no_transform num:2 hid:128)\n", - "[4/18] SUCCESS | Total: 134,679 (mod:gcn tra:no_transform num:2 hid:256)\n", - "[5/18] SUCCESS | Total: 67,607 (mod:gcn tra:no_transform num:4 hid:128)\n", - "[6/18] SUCCESS | Total: 266,263 (mod:gcn tra:no_transform num:4 hid:256)\n" + "[1/18] SUCCESS | Total: 1,005,778 (mod:gcn tra:no_transform num:1 hid:512)\n", + "[2/18] SUCCESS | Total: 2,531,538 (mod:gcn tra:no_transform num:1 hid:1024)\n", + "[3/18] SUCCESS | Total: 1,268,434 (mod:gcn tra:no_transform num:2 hid:512)\n", + "[4/18] SUCCESS | Total: 3,581,138 (mod:gcn tra:no_transform num:2 hid:1024)\n", + "[5/18] SUCCESS | Total: 1,793,746 (mod:gcn tra:no_transform num:4 hid:512)\n", + "[6/18] SUCCESS | Total: 5,680,338 (mod:gcn tra:no_transform num:4 hid:1024)\n" ] }, { @@ -234,39 +300,14 @@ "output_type": "stream", "text": [ "\n", - "Applying transforms to 188 graphs...\n" + "Applying transforms to 1 graphs...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Processing graphs: 100%|██████████| 188/188 [00:02<00:00, 86.60graph/s]\n", - "Done!\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[7/18] SUCCESS | Total: 24,228 (mod:gcn tra:combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE] num:1 hid:128)\n", - "Transform parameters are the same, using existing data_dir: /home/marco/Documents/phd/TopoBench/datasets/graph/TUDataset/MUTAG/CombinedPSEs/2691934090\n", - "[8/18] SUCCESS | Total: 81,060 (mod:gcn tra:combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE] num:1 hid:256)\n", - "Transform parameters are the same, using existing data_dir: /home/marco/Documents/phd/TopoBench/datasets/graph/TUDataset/MUTAG/CombinedPSEs/2691934090\n", - "[9/18] SUCCESS | Total: 40,740 (mod:gcn tra:combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE] num:2 
hid:128)\n", - "Transform parameters are the same, using existing data_dir: /home/marco/Documents/phd/TopoBench/datasets/graph/TUDataset/MUTAG/CombinedPSEs/2691934090\n", - "[10/18] SUCCESS | Total: 146,852 (mod:gcn tra:combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE] num:2 hid:256)\n", - "Transform parameters are the same, using existing data_dir: /home/marco/Documents/phd/TopoBench/datasets/graph/TUDataset/MUTAG/CombinedPSEs/2691934090\n", - "[11/18] SUCCESS | Total: 73,764 (mod:gcn tra:combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE] num:4 hid:128)\n", - "Transform parameters are the same, using existing data_dir: /home/marco/Documents/phd/TopoBench/datasets/graph/TUDataset/MUTAG/CombinedPSEs/2691934090\n", - "[12/18] SUCCESS | Total: 278,436 (mod:gcn tra:combined_pe@@@transforms.CombinedPSEs.encodings=[LapPE,RWSE,ElectrostaticPE,HKdiagSE] num:4 hid:256)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Processing...\n" + "Processing graphs: 0%| | 0/1 [00:00 \u001b[39m\u001b[32m5\u001b[39m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43msh_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_arrays\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_path\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 193\u001b[39m, in \u001b[36mmain\u001b[39m\u001b[34m(sh_path, sweep_arrays, out_path)\u001b[39m\n\u001b[32m 189\u001b[39m loaded_dataset, dataset_dir = dataset_loader.load()\n\u001b[32m 191\u001b[39m transform_config = hydra.utils.instantiate(cfg.transforms) \u001b[38;5;28;01mif\u001b[39;00m cfg.get(\u001b[33m\"\u001b[39m\u001b[33mtransforms\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m193\u001b[39m preprocessor = \u001b[43mPreProcessor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloaded_dataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtransform_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 194\u001b[39m dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(cfg.dataset.split_params)\n\u001b[32m 196\u001b[39m datamodule = TBDataloader(\n\u001b[32m 197\u001b[39m dataset_train=dataset_train,\n\u001b[32m 198\u001b[39m dataset_val=dataset_val,\n\u001b[32m 199\u001b[39m dataset_test=dataset_test,\n\u001b[32m 200\u001b[39m **cfg.dataset.get(\u001b[33m\"\u001b[39m\u001b[33mdataloader_params\u001b[39m\u001b[33m\"\u001b[39m, {}),\n\u001b[32m 201\u001b[39m )\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/phd/TopoBench/topobench/data/preprocessor/preprocessor.py:47\u001b[39m, in \u001b[36mPreProcessor.__init__\u001b[39m\u001b[34m(self, dataset, data_dir, transforms_config, **kwargs)\u001b[39m\n\u001b[32m 45\u001b[39m \u001b[38;5;66;03m# Record the time taken for preprocessing\u001b[39;00m\n\u001b[32m 46\u001b[39m start_time = time.time()\n\u001b[32m---> \u001b[39m\u001b[32m47\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[32m 48\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mprocessed_data_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpre_transform\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m 49\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 50\u001b[39m end_time = time.time()\n\u001b[32m 51\u001b[39m \u001b[38;5;28mself\u001b[39m.preprocessing_time = end_time - start_time\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/data/in_memory_dataset.py:81\u001b[39m, in \u001b[36mInMemoryDataset.__init__\u001b[39m\u001b[34m(self, root, transform, pre_transform, pre_filter, log, force_reload)\u001b[39m\n\u001b[32m 72\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__init__\u001b[39m(\n\u001b[32m 73\u001b[39m \u001b[38;5;28mself\u001b[39m,\n\u001b[32m 74\u001b[39m root: Optional[\u001b[38;5;28mstr\u001b[39m] = \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m (...)\u001b[39m\u001b[32m 79\u001b[39m force_reload: \u001b[38;5;28mbool\u001b[39m = \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[32m 80\u001b[39m ) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m81\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpre_transform\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpre_filter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlog\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 82\u001b[39m \u001b[43m \u001b[49m\u001b[43mforce_reload\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 84\u001b[39m \u001b[38;5;28mself\u001b[39m._data: Optional[BaseData] = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 85\u001b[39m \u001b[38;5;28mself\u001b[39m.slices: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, Tensor]] = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/data/dataset.py:115\u001b[39m, in \u001b[36mDataset.__init__\u001b[39m\u001b[34m(self, root, transform, pre_transform, pre_filter, log, force_reload)\u001b[39m\n\u001b[32m 112\u001b[39m \u001b[38;5;28mself\u001b[39m._download()\n\u001b[32m 114\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.has_process:\n\u001b[32m--> \u001b[39m\u001b[32m115\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_process\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/data/dataset.py:265\u001b[39m, in \u001b[36mDataset._process\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 262\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m'\u001b[39m\u001b[33mProcessing...\u001b[39m\u001b[33m'\u001b[39m, file=sys.stderr)\n\u001b[32m 264\u001b[39m fs.makedirs(\u001b[38;5;28mself\u001b[39m.processed_dir, exist_ok=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m--> \u001b[39m\u001b[32m265\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mprocess\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 267\u001b[39m path = osp.join(\u001b[38;5;28mself\u001b[39m.processed_dir, 
\u001b[33m'\u001b[39m\u001b[33mpre_transform.pt\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 268\u001b[39m fs.torch_save(_repr(\u001b[38;5;28mself\u001b[39m.pre_transform), path)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/phd/TopoBench/topobench/data/preprocessor/preprocessor.py:219\u001b[39m, in \u001b[36mPreProcessor.process\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 217\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.pre_transform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 218\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mApplying transforms to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(data_list)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m graphs...\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m219\u001b[39m \u001b[38;5;28mself\u001b[39m.data_list = \u001b[43m[\u001b[49m\n\u001b[32m 220\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpre_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 221\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtqdm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mProcessing graphs\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munit\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mgraph\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 222\u001b[39m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\n\u001b[32m 223\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 224\u001b[39m \u001b[38;5;28mself\u001b[39m.data_list = data_list\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/phd/TopoBench/topobench/data/preprocessor/preprocessor.py:220\u001b[39m, in \u001b[36m\u001b[39m\u001b[34m(.0)\u001b[39m\n\u001b[32m 217\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.pre_transform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 218\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mApplying transforms to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(data_list)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m graphs...\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 219\u001b[39m \u001b[38;5;28mself\u001b[39m.data_list = [\n\u001b[32m--> \u001b[39m\u001b[32m220\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpre_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 221\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m d \u001b[38;5;129;01min\u001b[39;00m tqdm(data_list, desc=\u001b[33m\"\u001b[39m\u001b[33mProcessing graphs\u001b[39m\u001b[33m\"\u001b[39m, unit=\u001b[33m\"\u001b[39m\u001b[33mgraph\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 222\u001b[39m ]\n\u001b[32m 223\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 224\u001b[39m \u001b[38;5;28mself\u001b[39m.data_list = 
data_list\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/transforms/base_transform.py:32\u001b[39m, in \u001b[36mBaseTransform.__call__\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 30\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, data: Any) -> Any:\n\u001b[32m 31\u001b[39m \u001b[38;5;66;03m# Shallow-copy the data so that we prevent in-place data modification.\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m32\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/transforms/compose.py:24\u001b[39m, in \u001b[36mCompose.forward\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 22\u001b[39m data = [transform(d) \u001b[38;5;28;01mfor\u001b[39;00m d \u001b[38;5;129;01min\u001b[39;00m data]\n\u001b[32m 23\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m24\u001b[39m data = \u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 25\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/transforms/base_transform.py:32\u001b[39m, in \u001b[36mBaseTransform.__call__\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 30\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, data: Any) -> Any:\n\u001b[32m 31\u001b[39m \u001b[38;5;66;03m# Shallow-copy the data so that we prevent in-place data modification.\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m32\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/phd/TopoBench/topobench/transforms/data_transform.py:46\u001b[39m, in \u001b[36mDataTransform.forward\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 31\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\n\u001b[32m 32\u001b[39m \u001b[38;5;28mself\u001b[39m, data: torch_geometric.data.Data\n\u001b[32m 33\u001b[39m ) -> torch_geometric.data.Data:\n\u001b[32m 34\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33mr\u001b[39m\u001b[33;03m\"\"\"Forward pass of the lifting.\u001b[39;00m\n\u001b[32m 35\u001b[39m \n\u001b[32m 36\u001b[39m \u001b[33;03m Parameters\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 44\u001b[39m \u001b[33;03m The lifted data.\u001b[39;00m\n\u001b[32m 45\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m46\u001b[39m transformed_data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 47\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m transformed_data\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/transforms/base_transform.py:32\u001b[39m, in \u001b[36mBaseTransform.__call__\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 30\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, data: Any) -> Any:\n\u001b[32m 31\u001b[39m \u001b[38;5;66;03m# Shallow-copy the data so that we prevent in-place data modification.\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m32\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/phd/TopoBench/topobench/transforms/data_manipulations/combined_positional_and_structural_encodings.py:116\u001b[39m, in \u001b[36mCombinedPSEs.forward\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 113\u001b[39m \u001b[38;5;66;03m# Instantiate and apply the encoder\u001b[39;00m\n\u001b[32m 114\u001b[39m \u001b[38;5;66;03m# The encoder naturally uses `current_device` because it reads `data.edge_index.device`\u001b[39;00m\n\u001b[32m 115\u001b[39m encoder = encoding_classes[enc](**enc_params)\n\u001b[32m--> \u001b[39m\u001b[32m116\u001b[39m data = \u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 118\u001b[39m \u001b[38;5;66;03m# Safety Net: Ensure the graph is returned to its original device before exiting\u001b[39;00m\n\u001b[32m 119\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m current_device != baseline_device:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/torch_geometric/transforms/base_transform.py:32\u001b[39m, in \u001b[36mBaseTransform.__call__\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 30\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, data: Any) -> Any:\n\u001b[32m 31\u001b[39m \u001b[38;5;66;03m# Shallow-copy the data so that we prevent in-place data modification.\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m32\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/phd/TopoBench/topobench/transforms/data_manipulations/electrostatic_encodings.py:71\u001b[39m, in \u001b[36mElectrostaticPE.forward\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 69\u001b[39m \u001b[38;5;66;03m# Exact Method (Original CPU NumPy)\u001b[39;00m\n\u001b[32m 70\u001b[39m t0 = time.time()\n\u001b[32m---> \u001b[39m\u001b[32m71\u001b[39m pe_numpy = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_compute_numpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m.\u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m.\u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 72\u001b[39m t_numpy = time.time() - t0\n\u001b[32m 73\u001b[39m 
\u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mExact compute time: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mt_numpy\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33ms\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/phd/TopoBench/topobench/transforms/data_manipulations/electrostatic_encodings.py:235\u001b[39m, in \u001b[36mElectrostaticPE._compute_numpy\u001b[39m\u001b[34m(self, edge_index, num_nodes)\u001b[39m\n\u001b[32m 232\u001b[39m A.fill_diagonal_(\u001b[32m0\u001b[39m)\n\u001b[32m 233\u001b[39m DinvA = Dinv.matmul(A)\n\u001b[32m--> \u001b[39m\u001b[32m235\u001b[39m evals, evecs = \u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlinalg\u001b[49m\u001b[43m.\u001b[49m\u001b[43meigh\u001b[49m\u001b[43m(\u001b[49m\u001b[43mL\u001b[49m\u001b[43m.\u001b[49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 236\u001b[39m evals = torch.from_numpy(evals)\n\u001b[32m 237\u001b[39m evecs = torch.from_numpy(evecs)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniforge3/envs/topobench/lib/python3.11/site-packages/numpy/linalg/linalg.py:1487\u001b[39m, in \u001b[36meigh\u001b[39m\u001b[34m(a, UPLO)\u001b[39m\n\u001b[32m 1484\u001b[39m gufunc = _umath_linalg.eigh_up\n\u001b[32m 1486\u001b[39m signature = \u001b[33m'\u001b[39m\u001b[33mD->dD\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m isComplexType(t) \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m'\u001b[39m\u001b[33md->dd\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m-> \u001b[39m\u001b[32m1487\u001b[39m w, vt = \u001b[43mgufunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msignature\u001b[49m\u001b[43m=\u001b[49m\u001b[43msignature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mextobj\u001b[49m\u001b[43m=\u001b[49m\u001b[43mextobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1488\u001b[39m w = w.astype(_realType(result_t), copy=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 1489\u001b[39m vt = vt.astype(result_t, copy=\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " ] } ],
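For readers who want to sanity-check the notebook's per-component parameter counts without composing a full Hydra config, the grouping logic boils down to bucketing named_parameters() by substrings of the dotted parameter name. Below is a standalone toy sketch of that idea; count_by_component and ToyNetwork are illustrative names invented for this example and are not part of the notebook or the TopoBench API.

import torch
import torch.nn as nn

def count_by_component(model: nn.Module, only_trainable: bool = True) -> tuple[int, dict[str, int]]:
    # Group parameter counts by a coarse label inferred from each parameter's dotted name.
    breakdown = {"Backbone": 0, "Feature_Encoder": 0, "Readout": 0, "Other": 0}
    total = 0
    for name, p in model.named_parameters():
        if only_trainable and not p.requires_grad:
            continue
        n = p.numel()
        total += n
        if "backbone" in name and "wrapper" not in name:
            breakdown["Backbone"] += n
        elif "feature_encoder" in name:
            breakdown["Feature_Encoder"] += n
        elif "readout" in name:
            breakdown["Readout"] += n
        else:
            breakdown["Other"] += n
    return total, breakdown

# Toy module whose attribute names mimic the encoder/backbone/readout layout the
# notebook's counter expects; it is only used for counting, so no forward() is needed.
class ToyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_encoder = nn.Linear(9, 64)
        self.backbone = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
        self.readout = nn.Linear(64, 2)

total, parts = count_by_component(ToyNetwork())
print(total, parts)  # total = (9*64 + 64) + 2*(64*64 + 64) + (64*2 + 2)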