Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"plotly>=6.2.0",
"logomaker>=0.8",
"networkx>=3.3",
"scikit-learn>=1.3",
"upsetplot"
]

Expand Down Expand Up @@ -58,6 +59,7 @@ Issues = "https://github.com/Multiomics-Analytics-Group/InstaNexus/issues"

[project.scripts]
instanexus = "instanexus.main:cli"
instanexus-optimize = "instanexus.optimize:cli"

# --- TOOL CONFIGURATIONS ---

Expand Down
119 changes: 118 additions & 1 deletion scripts/optimization/analyze_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
"""

import glob
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# --- CONFIGURATION ---
Expand Down Expand Up @@ -203,6 +205,121 @@ def save_detailed_rankings(df_best, output_dir, mode_name):
final_df.to_csv(out_path, index=False)


def combine_json_to_csv(
run: str,
type_method: str,
type_sequence: str,
base_path: Path = Path("outputs"),
) -> None:
"""Walks output directories, reads JSON stats files, and saves a combined CSV.

Args:
run: Run identifier (e.g. 'bsa', 'ma1').
type_method: Assembly method prefix used in the JSON filename (e.g. 'scaffolds').
type_sequence: Sequence type suffix used in the JSON filename (e.g. 'contigs').
base_path: Root outputs folder.
"""
run_path = Path(base_path) / run
dataframes = []
files_added = 0

for root, dirs, _ in os.walk(run_path):
for dir_name in dirs:
json_path = Path(root) / dir_name / "statistics" / f"{type_method}_{type_sequence}_stats.json"
if json_path.exists():
try:
with open(json_path) as f:
data = json.load(f)
df = pd.json_normalize(data)
df["source"] = dir_name
dataframes.append(df)
files_added += 1
except Exception as e:
print(f"Error loading {json_path}: {e}")

if not dataframes:
print(f"No stats files found under {run_path}.")
return

combined_df = pd.concat(dataframes, ignore_index=True)

if "ass_method" in combined_df.columns:
combined_df["ass_method"] = combined_df["ass_method"].fillna("greedy")

combined_df["sequence_type"] = type_sequence
combined_df["method_type"] = type_method
combined_df["run"] = run

output_file = run_path / f"{type_sequence}_combined_stats.csv"
combined_df.to_csv(output_file, index=False)
print(f"Combined stats saved to: {output_file} ({files_added} files merged)")


def plot_coverages_from_runs(
runs: list,
base_path: Path = Path("outputs"),
combination_folder: str = "",
contigs_json: str = "contigs_stats.json",
scaffolds_json: str = "scaffolds_stats.json",
save: bool = False,
output_dir: Path = Path("."),
) -> None:
"""Plots coverage barplots for contigs and scaffolds across multiple runs.

Args:
runs: List of run identifiers to include (e.g. ['bsa', 'nb1']).
base_path: Root outputs folder.
combination_folder: Sub-folder name for the specific parameter combination.
contigs_json: Filename of the contigs stats JSON (default: contigs_stats.json).
scaffolds_json: Filename of the scaffolds stats JSON (default: scaffolds_stats.json).
save: If True, saves plots as PNG files.
output_dir: Directory where PNG files are saved when save=True.
"""
base_path = Path(base_path)
contig_coverages: list = []
scaffold_coverages: list = []
labels: list = []

for run in runs:
stats_path = base_path / run / combination_folder / "statistics"
if not stats_path.exists():
print(f"[{run}] Missing statistics folder: {stats_path}")
continue

for coverage_list, fname in [(contig_coverages, contigs_json), (scaffold_coverages, scaffolds_json)]:
json_path = stats_path / fname
if json_path.exists():
try:
with open(json_path) as f:
coverage_list.append(json.load(f).get("coverage", 0))
except Exception as e:
print(f"[{run}] Error reading {fname}: {e}")
coverage_list.append(0)
else:
print(f"[{run}] {fname} not found.")
coverage_list.append(0)

labels.append(run)

for coverages, color, title, suffix in [
(contig_coverages, "mediumslateblue", "Contigs Coverage per Run", "contigs"),
(scaffold_coverages, "seagreen", "Scaffolds Coverage per Run", "scaffolds"),
]:
plt.figure(figsize=(10, 4))
plt.bar(labels, coverages, color=color)
plt.ylabel("Coverage")
plt.title(title)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()

if save:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(output_dir / f"{suffix}_coverage.png", dpi=300)

plt.show()


def main():
print("--- Starting Optimization Analysis (Seaborn Edition) ---")

Expand Down
119 changes: 82 additions & 37 deletions scripts/optimization/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import json
import logging
import sys
import tempfile
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
Expand Down Expand Up @@ -82,59 +83,55 @@ def load_grid_params(json_path: Path, mode: str) -> List[Dict[str, Any]]:


def compute_final_ranking(df_results: pd.DataFrame) -> pd.DataFrame:
"""
Applies MinMax scaling to normalize metrics and computes the Composite Score.

WEIGHTING STRATEGY: 'Aggressive Consolidation Split'
----------------------------------------------------
1. Coverage (0.35): DOMINANT.
Rationale: The primary goal is to recover the protein sequence. High N50
is useless if we only recover 10% of the target.

2. N50 (0.25) & Scaffold Count (0.25): STRUCTURAL (50% Combined).
Rationale: We strongly penalize fragmentation. We want the algorithm to
prioritize merging contigs into longer, fewer scaffolds over keeping
them separate to maximize local identity.

3. Mean Identity (0.15): QUALITY.
Rationale: Lower weight because input data is usually pre-filtered
(e.g., >80% identity during mapping). Differences between 95% and 99%
are less critical than differences in coverage or fragmentation.
"""Applies MinMax scaling and computes the Composite Score for ranking.

As defined in Reverenna et al., bioRxiv 2025.

Formula:
composite_score = 0.5 * coverage_norm
+ 0.3 * N50_norm
+ 0.1 * (1 - scaffolds_count_norm) # inverted: fewer = better
+ 0.1 * max_length_norm

Note: mean_identity is collected for reporting but is NOT part of this
formula. Do not add it here. For benchmarking against other tools use the
AQS formula defined in Reverenna et al., MCP 2026.

Args:
df_results: DataFrame with one row per parameter combination.

Returns:
df_results with a composite_score column, sorted descending.
"""
if df_results.empty:
return df_results

# Metrics to use for scoring
metrics = ["coverage", "N50", "mean_identity", "scaffolds_count"]
# Metrics used in the Composite Score (Reverenna et al., bioRxiv 2025)
metrics = ["coverage", "N50", "scaffolds_count", "max_length"]

# Check if metrics exist
available_metrics = [m for m in metrics if m in df_results.columns]
if len(available_metrics) != len(metrics):
logger.warning("Some metrics missing from results. Skipping scoring.")
return df_results

df_scoring = df_results[metrics].copy().fillna(0)

# Normalize (0 to 1)
# Normalize all metrics to [0, 1]
scaler = MinMaxScaler()
scaled_values = scaler.fit_transform(df_scoring)
df_scaled = pd.DataFrame(scaled_values, columns=metrics, index=df_results.index)

# Invert 'scaffolds_count' because fewer is better.
# Formula: 1 - normalized_val (where 1 becomes best/fewest, 0 becomes worst/most)
# Invert scaffolds_count: fewer assembled sequences = better
df_scaled["scaffolds_count"] = 1 - df_scaled["scaffolds_count"]

# Define Weights
weights = {"coverage": 0.35, "N50": 0.25, "scaffolds_count": 0.25, "mean_identity": 0.15}
# Composite Score weights (Reverenna et al., bioRxiv 2025)
weights = {"coverage": 0.5, "N50": 0.3, "scaffolds_count": 0.1, "max_length": 0.1}

# Calculate Weighted Sum
composite_scores = df_scaled[list(weights.keys())].dot(pd.Series(weights))

# Merge back
df_final = df_results.copy()
df_final["composite_score"] = composite_scores

# Sort by score descending
return df_final.sort_values(by="composite_score", ascending=False)


Expand Down Expand Up @@ -193,18 +190,20 @@ def evaluate_combination(

df_mapped = visualization.create_dataframe_from_mapped_sequences(mapped_scaffolds)

stats = helpers.compute_assembly_statistics(
df=df_mapped,
sequence_type="scaffolds",
output_folder="", # We don't save individual JSONs to save IO/Time
reference=protein_norm,
)
with tempfile.TemporaryDirectory() as _tmpdir:
stats = helpers.compute_assembly_statistics(
df=df_mapped,
sequence_type="scaffolds",
output_folder=_tmpdir,
reference=protein_norm,
)

return {
**params,
"scaffolds_count": len(scaffolds),
"coverage": stats.get("coverage", 0),
"N50": stats.get("N50", 0),
"max_length": stats.get("max_length", 0),
"mean_identity": stats.get("mean_identity", 0),
"total_mismatches": stats.get("total_mismatches", 0),
"duration_sec": round(duration, 2),
Expand All @@ -216,6 +215,49 @@ def evaluate_combination(
return {**params, "error": str(e)}


def _build_instanexus_command(
input_csv: str,
mode: str,
best_params: Dict[str, Any],
best_score: float,
) -> str:
"""Builds a ready-to-run instanexus CLI command from the best grid parameters.

Args:
input_csv: Path to the input CSV used during the search.
mode: Assembly mode (e.g. 'dbg_weighted').
best_params: Dict of parameter name → best value from the grid.
best_score: Best composite score achieved.

Returns:
A formatted string with the suggested instanexus command.
"""
# Map grid-search parameter names to instanexus CLI flags
cli_map: Dict[str, str] = {
"fdr": "--fdr",
"kmer_size": "--kmer-size",
"min_overlap": "--min-overlap",
"size_threshold": "--size-threshold",
}
int_params = {"kmer_size", "min_overlap", "size_threshold"}

parts = [
"instanexus",
f"--input-csv {input_csv}",
f"--assembly-mode {mode}",
]

for key, flag in cli_map.items():
if key in best_params:
val = int(best_params[key]) if key in int_params else best_params[key]
parts.append(f"{flag} {val}")

if int(best_params.get("refine_rounds", 0)) > 0:
parts.append("--refine")

return f"Based on optimization (composite_score={best_score:.3f}), run:\n {' '.join(parts)}"


def main():
parser = argparse.ArgumentParser(description="Hyperparameter Grid Search for InstaNexus.")
parser.add_argument("--input-csv", required=True, help="Path to input cleaned CSV.")
Expand Down Expand Up @@ -276,7 +318,7 @@ def main():
df_valid = df_results[df_results["error"].isnull()].copy()

if not df_valid.empty:
logger.info("Computing final ranking (Aggressive Consolidation Split)...")
logger.info("Computing final ranking (Composite Score, Reverenna et al., bioRxiv 2025)...")
df_ranked = compute_final_ranking(df_valid)

df_errors = df_results[df_results["error"].notnull()]
Expand All @@ -296,8 +338,11 @@ def main():
f" Cov: {best['coverage'] * 100:.1f}% | N50: {best['N50']} | Scaffolds: {best['scaffolds_count']}"
)

best_params = {k: best[k] for k in combinations[0].keys() if k in best}
best_params = {
k: best[k].item() if hasattr(best[k], "item") else best[k] for k in combinations[0].keys() if k in best
}
logger.info(f" Params: {json.dumps(best_params, indent=2)}")
logger.info(_build_instanexus_command(args.input_csv, args.mode, best_params, best["composite_score"]))
else:
logger.warning("All runs failed or produced no valid assemblies.")
csv_out = output_dir / f"grid_{args.mode}_{run_name}_FAILED.csv"
Expand Down
21 changes: 21 additions & 0 deletions src/instanexus/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env python
"""CLI entry point for the instanexus-optimize command."""

import sys
from pathlib import Path


def cli() -> None:
"""Launch the hyperparameter grid search optimizer.

Delegates to scripts/optimization/grid_search.py, which contains
the full implementation. The package entry point exists here so that
`instanexus-optimize` is available after `uv sync`.
"""
_scripts_opt = Path(__file__).resolve().parents[2] / "scripts" / "optimization"
if str(_scripts_opt) not in sys.path:
sys.path.insert(0, str(_scripts_opt))

import grid_search # type: ignore[import-not-found]

grid_search.main()
Loading
Loading