Skip to content

Commit dc06b69

Browse files
Merge pull request #28 from Multiomics-Analytics-Group/27-fix-feature-separate-composite-score-grid-search-from-aqs-benchmark-and-add-instanexus-optimize-cli
27 fix feature separate composite score grid search from aqs benchmark and add instanexus optimize cli
2 parents a6dc65e + 102cc6a commit dc06b69

5 files changed

Lines changed: 493 additions & 38 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"plotly>=6.2.0",
3232
"logomaker>=0.8",
3333
"networkx>=3.3",
34+
"scikit-learn>=1.3",
3435
"upsetplot"
3536
]
3637

@@ -58,6 +59,7 @@ Issues = "https://github.com/Multiomics-Analytics-Group/InstaNexus/issues"
5859

5960
[project.scripts]
6061
instanexus = "instanexus.main:cli"
62+
instanexus-optimize = "instanexus.optimize:cli"
6163

6264
# --- TOOL CONFIGURATIONS ---
6365

scripts/optimization/analyze_optimization.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
"""
88

99
import glob
10+
import json
1011
import os
1112
from pathlib import Path
13+
14+
import matplotlib.pyplot as plt
1215
import pandas as pd
1316
import seaborn as sns
14-
import matplotlib.pyplot as plt
1517
from matplotlib.colors import LinearSegmentedColormap
1618

1719
# --- CONFIGURATION ---
@@ -203,6 +205,121 @@ def save_detailed_rankings(df_best, output_dir, mode_name):
203205
final_df.to_csv(out_path, index=False)
204206

205207

208+
def combine_json_to_csv(
    run: str,
    type_method: str,
    type_sequence: str,
    base_path: Path = Path("outputs"),
) -> None:
    """Walk a run's output directories, read JSON stats files, and save a combined CSV.

    Every directory below ``base_path / run`` is checked for a file named
    ``statistics/{type_method}_{type_sequence}_stats.json``. Each file found is
    flattened with ``pandas.json_normalize`` and tagged with the name of the
    directory it was found under (``source`` column). The merged table is
    written to ``base_path / run / {type_sequence}_combined_stats.csv``.

    Directories are visited in sorted order so that the row order of the CSV is
    reproducible across machines and filesystems. Files that cannot be read or
    parsed are reported on stdout and skipped (best effort). If no stats file is
    found, a message is printed and no CSV is written.

    Args:
        run: Run identifier (e.g. 'bsa', 'ma1').
        type_method: Assembly method prefix used in the JSON filename (e.g. 'scaffolds').
        type_sequence: Sequence type suffix used in the JSON filename (e.g. 'contigs').
        base_path: Root outputs folder.
    """
    run_path = Path(base_path) / run
    stats_filename = f"{type_method}_{type_sequence}_stats.json"
    dataframes = []

    for root, dirs, _ in os.walk(run_path):
        # os.walk lists sub-directories in arbitrary (filesystem) order; sorting
        # the list in place fixes both the traversal order and the CSV row order.
        dirs.sort()
        for dir_name in dirs:
            json_path = Path(root) / dir_name / "statistics" / stats_filename
            if not json_path.exists():
                continue
            try:
                # Explicit UTF-8: open() would otherwise use the locale encoding.
                with open(json_path, encoding="utf-8") as f:
                    data = json.load(f)
                df = pd.json_normalize(data)
                df["source"] = dir_name
            except Exception as e:  # best effort: report and keep merging the rest
                print(f"Error loading {json_path}: {e}")
                continue
            dataframes.append(df)

    if not dataframes:
        print(f"No stats files found under {run_path}.")
        return

    combined_df = pd.concat(dataframes, ignore_index=True)

    # Rows without an assembly method are labelled 'greedy'.
    if "ass_method" in combined_df.columns:
        combined_df["ass_method"] = combined_df["ass_method"].fillna("greedy")

    combined_df["sequence_type"] = type_sequence
    combined_df["method_type"] = type_method
    combined_df["run"] = run

    output_file = run_path / f"{type_sequence}_combined_stats.csv"
    combined_df.to_csv(output_file, index=False)
    print(f"Combined stats saved to: {output_file} ({len(dataframes)} files merged)")
256+
257+
258+
def plot_coverages_from_runs(
    runs: list,
    base_path: Path = Path("outputs"),
    combination_folder: str = "",
    contigs_json: str = "contigs_stats.json",
    scaffolds_json: str = "scaffolds_stats.json",
    save: bool = False,
    output_dir: Path = Path("."),
) -> None:
    """Plot coverage barplots for contigs and scaffolds across multiple runs.

    For each run, the ``coverage`` value is read from the contigs and scaffolds
    stats JSON files located in ``base_path / run / combination_folder /
    statistics``. Runs without a statistics folder are reported and left out of
    the plots. A missing or unreadable JSON file is reported and plotted as a
    coverage of 0 so that both bar charts keep one bar per included run.

    Two figures are produced (contigs, then scaffolds). Each one is shown and
    then closed so that repeated calls do not accumulate open figures.

    Args:
        runs: List of run identifiers to include (e.g. ['bsa', 'nb1']).
        base_path: Root outputs folder.
        combination_folder: Sub-folder name for the specific parameter combination.
        contigs_json: Filename of the contigs stats JSON (default: contigs_stats.json).
        scaffolds_json: Filename of the scaffolds stats JSON (default: scaffolds_stats.json).
        save: If True, saves plots as PNG files.
        output_dir: Directory where PNG files are saved when save=True.
    """
    base_path = Path(base_path)
    contig_coverages: list = []
    scaffold_coverages: list = []
    labels: list = []

    for run in runs:
        stats_path = base_path / run / combination_folder / "statistics"
        if not stats_path.exists():
            print(f"[{run}] Missing statistics folder: {stats_path}")
            continue

        for coverage_list, fname in [(contig_coverages, contigs_json), (scaffold_coverages, scaffolds_json)]:
            coverage = 0  # fallback keeps the coverage lists aligned with `labels`
            json_path = stats_path / fname
            if json_path.exists():
                try:
                    # Explicit UTF-8: open() would otherwise use the locale encoding.
                    with open(json_path, encoding="utf-8") as f:
                        coverage = json.load(f).get("coverage", 0)
                except Exception as e:  # best effort: report and plot 0
                    print(f"[{run}] Error reading {fname}: {e}")
            else:
                print(f"[{run}] {fname} not found.")
            coverage_list.append(coverage)

        labels.append(run)

    # Prepare the output directory once instead of once per figure.
    if save:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    for coverages, color, title, suffix in [
        (contig_coverages, "mediumslateblue", "Contigs Coverage per Run", "contigs"),
        (scaffold_coverages, "seagreen", "Scaffolds Coverage per Run", "scaffolds"),
    ]:
        fig = plt.figure(figsize=(10, 4))
        plt.bar(labels, coverages, color=color)
        plt.ylabel("Coverage")
        plt.title(title)
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()

        if save:
            plt.savefig(output_dir / f"{suffix}_coverage.png", dpi=300)

        plt.show()
        # pyplot keeps figures registered until they are closed explicitly;
        # release this one so repeated calls do not pile up open figures.
        plt.close(fig)
321+
322+
206323
def main():
207324
print("--- Starting Optimization Analysis (Seaborn Edition) ---")
208325

scripts/optimization/grid_search.py

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import json
3939
import logging
4040
import sys
41+
import tempfile
4142
import time
4243
from concurrent.futures import ProcessPoolExecutor, as_completed
4344
from pathlib import Path
@@ -82,59 +83,55 @@ def load_grid_params(json_path: Path, mode: str) -> List[Dict[str, Any]]:
8283

8384

8485
def compute_final_ranking(df_results: pd.DataFrame) -> pd.DataFrame:
    """Apply MinMax scaling and compute the Composite Score used for ranking.

    As defined in Reverenna et al., bioRxiv 2025.

    Formula:
        composite_score = 0.5 * coverage_norm
                        + 0.3 * N50_norm
                        + 0.1 * (1 - scaffolds_count_norm)   # inverted: fewer = better
                        + 0.1 * max_length_norm

    Note: mean_identity is collected for reporting but is NOT part of this
    formula. Do not add it here. For benchmarking against other tools use the
    AQS formula defined in Reverenna et al., MCP 2026.

    Args:
        df_results: DataFrame with one row per parameter combination.

    Returns:
        A copy of df_results with a composite_score column, sorted by that
        column in descending order. If df_results is empty, or any of the four
        scoring metrics is missing, df_results is returned unchanged and
        WITHOUT a composite_score column.
    """
    if df_results.empty:
        return df_results

    # Metrics used in the Composite Score (Reverenna et al., bioRxiv 2025)
    metrics = ["coverage", "N50", "scaffolds_count", "max_length"]

    missing_metrics = [m for m in metrics if m not in df_results.columns]
    if missing_metrics:
        # Name the missing columns so the skipped ranking can be diagnosed.
        logger.warning(
            "Some metrics missing from results (%s). Skipping scoring.",
            ", ".join(missing_metrics),
        )
        return df_results

    # Missing values are scored as 0.
    df_scoring = df_results[metrics].copy().fillna(0)

    # Normalize all metrics to [0, 1]
    scaler = MinMaxScaler()
    scaled_values = scaler.fit_transform(df_scoring)
    df_scaled = pd.DataFrame(scaled_values, columns=metrics, index=df_results.index)

    # Invert scaffolds_count: fewer assembled sequences = better
    df_scaled["scaffolds_count"] = 1 - df_scaled["scaffolds_count"]

    # Composite Score weights (Reverenna et al., bioRxiv 2025)
    weights = {"coverage": 0.5, "N50": 0.3, "scaffolds_count": 0.1, "max_length": 0.1}

    # Weighted sum per row; the Series index aligns with the selected columns.
    composite_scores = df_scaled[list(weights.keys())].dot(pd.Series(weights))

    df_final = df_results.copy()
    df_final["composite_score"] = composite_scores

    return df_final.sort_values(by="composite_score", ascending=False)
139136

140137

@@ -193,18 +190,20 @@ def evaluate_combination(
193190

194191
df_mapped = visualization.create_dataframe_from_mapped_sequences(mapped_scaffolds)
195192

196-
stats = helpers.compute_assembly_statistics(
197-
df=df_mapped,
198-
sequence_type="scaffolds",
199-
output_folder="", # We don't save individual JSONs to save IO/Time
200-
reference=protein_norm,
201-
)
193+
with tempfile.TemporaryDirectory() as _tmpdir:
194+
stats = helpers.compute_assembly_statistics(
195+
df=df_mapped,
196+
sequence_type="scaffolds",
197+
output_folder=_tmpdir,
198+
reference=protein_norm,
199+
)
202200

203201
return {
204202
**params,
205203
"scaffolds_count": len(scaffolds),
206204
"coverage": stats.get("coverage", 0),
207205
"N50": stats.get("N50", 0),
206+
"max_length": stats.get("max_length", 0),
208207
"mean_identity": stats.get("mean_identity", 0),
209208
"total_mismatches": stats.get("total_mismatches", 0),
210209
"duration_sec": round(duration, 2),
@@ -216,6 +215,49 @@ def evaluate_combination(
216215
return {**params, "error": str(e)}
217216

218217

218+
def _build_instanexus_command(
219+
input_csv: str,
220+
mode: str,
221+
best_params: Dict[str, Any],
222+
best_score: float,
223+
) -> str:
224+
"""Builds a ready-to-run instanexus CLI command from the best grid parameters.
225+
226+
Args:
227+
input_csv: Path to the input CSV used during the search.
228+
mode: Assembly mode (e.g. 'dbg_weighted').
229+
best_params: Dict of parameter name → best value from the grid.
230+
best_score: Best composite score achieved.
231+
232+
Returns:
233+
A formatted string with the suggested instanexus command.
234+
"""
235+
# Map grid-search parameter names to instanexus CLI flags
236+
cli_map: Dict[str, str] = {
237+
"fdr": "--fdr",
238+
"kmer_size": "--kmer-size",
239+
"min_overlap": "--min-overlap",
240+
"size_threshold": "--size-threshold",
241+
}
242+
int_params = {"kmer_size", "min_overlap", "size_threshold"}
243+
244+
parts = [
245+
"instanexus",
246+
f"--input-csv {input_csv}",
247+
f"--assembly-mode {mode}",
248+
]
249+
250+
for key, flag in cli_map.items():
251+
if key in best_params:
252+
val = int(best_params[key]) if key in int_params else best_params[key]
253+
parts.append(f"{flag} {val}")
254+
255+
if int(best_params.get("refine_rounds", 0)) > 0:
256+
parts.append("--refine")
257+
258+
return f"Based on optimization (composite_score={best_score:.3f}), run:\n {' '.join(parts)}"
259+
260+
219261
def main():
220262
parser = argparse.ArgumentParser(description="Hyperparameter Grid Search for InstaNexus.")
221263
parser.add_argument("--input-csv", required=True, help="Path to input cleaned CSV.")
@@ -276,7 +318,7 @@ def main():
276318
df_valid = df_results[df_results["error"].isnull()].copy()
277319

278320
if not df_valid.empty:
279-
logger.info("Computing final ranking (Aggressive Consolidation Split)...")
321+
logger.info("Computing final ranking (Composite Score, Reverenna et al., bioRxiv 2025)...")
280322
df_ranked = compute_final_ranking(df_valid)
281323

282324
df_errors = df_results[df_results["error"].notnull()]
@@ -296,8 +338,11 @@ def main():
296338
f" Cov: {best['coverage'] * 100:.1f}% | N50: {best['N50']} | Scaffolds: {best['scaffolds_count']}"
297339
)
298340

299-
best_params = {k: best[k] for k in combinations[0].keys() if k in best}
341+
best_params = {
342+
k: best[k].item() if hasattr(best[k], "item") else best[k] for k in combinations[0].keys() if k in best
343+
}
300344
logger.info(f" Params: {json.dumps(best_params, indent=2)}")
345+
logger.info(_build_instanexus_command(args.input_csv, args.mode, best_params, best["composite_score"]))
301346
else:
302347
logger.warning("All runs failed or produced no valid assemblies.")
303348
csv_out = output_dir / f"grid_{args.mode}_{run_name}_FAILED.csv"

src/instanexus/optimize.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/usr/bin/env python
2+
"""CLI entry point for the instanexus-optimize command."""
3+
4+
import sys
5+
from pathlib import Path
6+
7+
8+
def cli() -> None:
    """Launch the hyperparameter grid search optimizer.

    Delegates to scripts/optimization/grid_search.py, which contains
    the full implementation. The package entry point exists here so that
    `instanexus-optimize` is available after `uv sync`.

    The scripts directory is located relative to this file (two levels above
    the package directory), so the command needs the repository's ``scripts/``
    folder to be present next to ``src/``.

    Raises:
        ModuleNotFoundError: If ``grid_search`` cannot be imported, with a
            message that names the directory that was searched. A missing
            dependency inside ``grid_search`` is re-raised unchanged.
    """
    _scripts_opt = Path(__file__).resolve().parents[2] / "scripts" / "optimization"
    if str(_scripts_opt) not in sys.path:
        sys.path.insert(0, str(_scripts_opt))

    try:
        import grid_search  # type: ignore[import-not-found]
    except ModuleNotFoundError as err:
        # Only rewrite the error when grid_search itself is missing; a missing
        # dependency of grid_search must surface with its original message.
        if err.name != "grid_search":
            raise
        raise ModuleNotFoundError(
            f"Cannot import 'grid_search': expected it in {_scripts_opt}. "
            "instanexus-optimize must be run from a source checkout that "
            "contains the scripts/optimization directory.",
            name="grid_search",
        ) from err

    grid_search.main()

0 commit comments

Comments
 (0)