|
1 | 1 | """CLI command for differential expression analysis.""" |
2 | 2 |
|
3 | 3 | import argparse |
| 4 | +import json |
4 | 5 | import sys |
5 | 6 | from pathlib import Path |
6 | 7 | import logging |
|
22 | 23 | logger = logging.getLogger("kompot.cli") |
23 | 24 |
|
24 | 25 |
|
| 26 | +def _json_default(obj): |
| 27 | + """JSON serializer for numpy types.""" |
| 28 | + import numpy as np |
| 29 | + |
| 30 | + if isinstance(obj, (np.integer,)): |
| 31 | + return int(obj) |
| 32 | + if isinstance(obj, (np.floating,)): |
| 33 | + return float(obj) |
| 34 | + if isinstance(obj, np.ndarray): |
| 35 | + return obj.tolist() |
| 36 | + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") |
| 37 | + |
| 38 | + |
25 | 39 | def add_de_parser(subparsers) -> argparse.ArgumentParser: |
26 | 40 | """ |
27 | 41 | Add differential expression subcommand parser. |
@@ -165,6 +179,14 @@ def add_de_parser(subparsers) -> argparse.ArgumentParser: |
165 | 179 | help="Estimate per-gene heteroscedastic noise from squared residuals to deflate significance for high-noise genes", |
166 | 180 | ) |
167 | 181 |
|
| 182 | + # Dry run |
| 183 | + parser.add_argument( |
| 184 | + "--dry-run", |
| 185 | + action="store_true", |
| 186 | + help="Estimate resource requirements instead of running the analysis. " |
| 187 | + "JSON to stdout, human-readable report to stderr. -o/--output is ignored.", |
| 188 | + ) |
| 189 | + |
168 | 190 | # Compute configuration |
169 | 191 | parser.add_argument( |
170 | 192 | "--use-gpu", |
@@ -192,8 +214,8 @@ def run_de(args): |
192 | 214 | args |
193 | 215 | Parsed arguments from argparse |
194 | 216 | """ |
195 | | - # Validate output arguments |
196 | | - if not args.output and not args.table_output: |
| 217 | + # Validate output arguments (not required for dry-run) |
| 218 | + if not args.dry_run and not args.output and not args.table_output: |
197 | 219 | logger.error("Either --output or --table-output must be specified") |
198 | 220 | sys.exit(1) |
199 | 221 |
|
@@ -245,6 +267,7 @@ def run_de(args): |
245 | 267 | "command", |
246 | 268 | "use_gpu", |
247 | 269 | "threads", |
| 270 | + "dry_run", |
248 | 271 | ] |
249 | 272 | } |
250 | 273 |
|
@@ -355,23 +378,45 @@ def run_de(args): |
355 | 378 | output_kwargs["return_full_results"] = True |
356 | 379 | output = OutputSettings(**output_kwargs) if output_kwargs else None |
357 | 380 |
|
| 381 | + # Build shared call kwargs |
| 382 | + call_kwargs = dict( |
| 383 | + groupby=groupby, |
| 384 | + condition1=condition1, |
| 385 | + condition2=condition2, |
| 386 | + obsm_key=obsm_key, |
| 387 | + layer=layer, |
| 388 | + sample_col=sample_col, |
| 389 | + gp=gp, |
| 390 | + fdr=fdr, |
| 391 | + filter=filter_settings, |
| 392 | + storage=storage, |
| 393 | + output=output, |
| 394 | + **params, # remaining params forwarded as function_kwargs |
| 395 | + ) |
| 396 | + |
| 397 | + # Dry run: estimate resources, output JSON to stdout, report to stderr |
| 398 | + if args.dry_run: |
| 399 | + import io |
| 400 | + |
| 401 | + try: |
| 402 | + old_stdout = sys.stdout |
| 403 | + sys.stdout = sys.stderr # capture de()'s print(plan.format_report()) |
| 404 | + try: |
| 405 | + plan = de(adata, dry_run=True, **call_kwargs) |
| 406 | + finally: |
| 407 | + sys.stdout = old_stdout |
| 408 | + except Exception as e: |
| 409 | + logger.error(f"Dry run failed: {str(e)}") |
| 410 | + raise |
| 411 | + |
| 412 | + # Machine-parseable JSON to stdout |
| 413 | + json.dump(plan.to_dict(), sys.stdout, default=_json_default, indent=2) |
| 414 | + print(file=sys.stdout) # trailing newline |
| 415 | + sys.exit(0 if plan.is_feasible else 1) |
| 416 | + |
358 | 417 | # Run analysis |
359 | 418 | try: |
360 | | - result_dict = de( |
361 | | - adata, |
362 | | - groupby=groupby, |
363 | | - condition1=condition1, |
364 | | - condition2=condition2, |
365 | | - obsm_key=obsm_key, |
366 | | - layer=layer, |
367 | | - sample_col=sample_col, |
368 | | - gp=gp, |
369 | | - fdr=fdr, |
370 | | - filter=filter_settings, |
371 | | - storage=storage, |
372 | | - output=output, |
373 | | - **params, # remaining params forwarded as function_kwargs |
374 | | - ) |
| 419 | + result_dict = de(adata, **call_kwargs) |
375 | 420 | except Exception as e: |
376 | 421 | logger.error(f"Analysis failed: {str(e)}") |
377 | 422 | raise |
|
0 commit comments