|
32 | 32 | from argparse import Namespace |
33 | 33 |
|
34 | 34 |
|
| 35 | +def detect_language_from_config(config: dict) -> str: |
| 36 | + """Detect the project language from config or file extensions. |
| 37 | +
|
| 38 | + Args: |
| 39 | + config: Project configuration dictionary. |
| 40 | +
|
| 41 | + Returns: |
| 42 | + Language identifier ('python', 'javascript', or 'typescript'). |
| 43 | +
|
| 44 | + """ |
| 45 | + # Check explicit language in config |
| 46 | + if "language" in config: |
| 47 | + return config["language"].lower() |
| 48 | + |
| 49 | + # Check module root for file types |
| 50 | + module_root = Path(config.get("module_root", ".")) |
| 51 | + if module_root.exists(): |
| 52 | + js_files = list(module_root.glob("**/*.js")) + list(module_root.glob("**/*.jsx")) |
| 53 | + ts_files = list(module_root.glob("**/*.ts")) + list(module_root.glob("**/*.tsx")) |
| 54 | + py_files = list(module_root.glob("**/*.py")) |
| 55 | + |
| 56 | + # Filter out node_modules |
| 57 | + js_files = [f for f in js_files if "node_modules" not in str(f)] |
| 58 | + ts_files = [f for f in ts_files if "node_modules" not in str(f)] |
| 59 | + |
| 60 | + total_js = len(js_files) + len(ts_files) |
| 61 | + total_py = len(py_files) |
| 62 | + |
| 63 | + if total_js > total_py: |
| 64 | + return "typescript" if len(ts_files) > len(js_files) else "javascript" |
| 65 | + |
| 66 | + return "python" |
| 67 | + |
| 68 | + |
35 | 69 | def main(args: Namespace | None = None) -> ArgumentParser: |
36 | 70 | parser = ArgumentParser(allow_abbrev=False) |
37 | 71 | parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.trace") |
@@ -60,6 +94,11 @@ def main(args: Namespace | None = None) -> ArgumentParser: |
60 | 94 | parser.add_argument( |
61 | 95 | "--limit", type=int, default=None, help="Limit the number of test files to process (for -m pytest mode)" |
62 | 96 | ) |
| 97 | + parser.add_argument( |
| 98 | + "--language", |
| 99 | + help="Language to trace (python, javascript, typescript). Auto-detected if not specified.", |
| 100 | + default=None, |
| 101 | + ) |
63 | 102 |
|
64 | 103 | if args is not None: |
65 | 104 | parsed_args = args |
@@ -93,6 +132,14 @@ def main(args: Namespace | None = None) -> ArgumentParser: |
93 | 132 | outfile = parsed_args.outfile |
94 | 133 | config, found_config_path = parse_config_file(parsed_args.codeflash_config) |
95 | 134 | project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path) |
| 135 | + |
| 136 | + # Detect or use specified language |
| 137 | + language = getattr(parsed_args, "language", None) or detect_language_from_config(config) |
| 138 | + |
| 139 | + # Route to appropriate tracer based on language |
| 140 | + if language in ("javascript", "typescript"): |
| 141 | + return run_javascript_tracer_main(parsed_args, config, project_root, outfile, unknown_args) |
| 142 | + |
96 | 143 | if len(unknown_args) > 0: |
97 | 144 | args_dict = { |
98 | 145 | "functions": parsed_args.only_functions, |
@@ -255,5 +302,89 @@ def main(args: Namespace | None = None) -> ArgumentParser: |
255 | 302 | return parser |
256 | 303 |
|
257 | 304 |
|
| 305 | +def run_javascript_tracer_main( |
| 306 | + parsed_args: Namespace, config: dict, project_root: Path, outfile: Path, unknown_args: list[str] |
| 307 | +) -> ArgumentParser: |
| 308 | + """Run the JavaScript tracer. |
| 309 | +
|
| 310 | + Args: |
| 311 | + parsed_args: Parsed command line arguments. |
| 312 | + config: Project configuration. |
| 313 | + project_root: Project root directory. |
| 314 | + outfile: Output trace file path. |
| 315 | + unknown_args: Remaining command line arguments. |
| 316 | +
|
| 317 | + Returns: |
| 318 | + The argument parser. |
| 319 | +
|
| 320 | + """ |
| 321 | + from codeflash.languages.javascript.tracer_runner import ( |
| 322 | + check_javascript_tracer_available, |
| 323 | + get_tracer_requirements_message, |
| 324 | + run_javascript_tracer, |
| 325 | + ) |
| 326 | + |
| 327 | + # Check requirements |
| 328 | + if not check_javascript_tracer_available(): |
| 329 | + console.print(f"[red]{get_tracer_requirements_message()}[/red]") |
| 330 | + sys.exit(1) |
| 331 | + |
| 332 | + # Prepare args for the tracer runner |
| 333 | + parsed_args.script_args = unknown_args |
| 334 | + |
| 335 | + # Run the tracer |
| 336 | + console.print("[bold blue]Running JavaScript tracer...[/bold blue]") |
| 337 | + result = run_javascript_tracer(parsed_args, config, project_root) |
| 338 | + |
| 339 | + if not result["success"]: |
| 340 | + console.print(f"[red]Tracing failed: {result.get('error', 'Unknown error')}[/red]") |
| 341 | + sys.exit(1) |
| 342 | + |
| 343 | + console.print(f"[green]Trace saved to: {result['trace_file']}[/green]") |
| 344 | + |
| 345 | + if result.get("replay_test_file"): |
| 346 | + console.print(f"[green]Replay test generated: {result['replay_test_file']}[/green]") |
| 347 | + |
| 348 | + # Run optimization if not trace-only mode |
| 349 | + if not parsed_args.trace_only: |
| 350 | + from codeflash.cli_cmds.cli import parse_args as cli_parse_args |
| 351 | + from codeflash.cli_cmds.cli import process_pyproject_config |
| 352 | + from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO |
| 353 | + from codeflash.cli_cmds.console import paneled_text |
| 354 | + from codeflash.languages import set_current_language |
| 355 | + from codeflash.languages.base import Language |
| 356 | + from codeflash.telemetry import posthog_cf |
| 357 | + from codeflash.telemetry.sentry import init_sentry |
| 358 | + |
| 359 | + # Set language to JavaScript |
| 360 | + set_current_language(Language.JAVASCRIPT) |
| 361 | + |
| 362 | + sys.argv = ["codeflash", "--replay-test", result["replay_test_file"]] |
| 363 | + args = cli_parse_args() |
| 364 | + paneled_text( |
| 365 | + CODEFLASH_LOGO, |
| 366 | + panel_args={"title": "https://codeflash.ai", "expand": False}, |
| 367 | + text_args={"style": "bold gold3"}, |
| 368 | + ) |
| 369 | + |
| 370 | + args = process_pyproject_config(args) |
| 371 | + args.previous_checkpoint_functions = None |
| 372 | + init_sentry(enabled=not args.disable_telemetry, exclude_errors=True) |
| 373 | + posthog_cf.initialize_posthog(enabled=not args.disable_telemetry) |
| 374 | + |
| 375 | + from codeflash.optimization import optimizer |
| 376 | + |
| 377 | + args.effort = EffortLevel.HIGH.value |
| 378 | + optimizer.run_with_args(args) |
| 379 | + |
| 380 | + # Clean up trace and replay test files |
| 381 | + if outfile: |
| 382 | + outfile.unlink(missing_ok=True) |
| 383 | + Path(result["replay_test_file"]).unlink(missing_ok=True) |
| 384 | + |
| 385 | + # Return a new parser for API compatibility |
| 386 | + return ArgumentParser(allow_abbrev=False) |
| 387 | + |
| 388 | + |
258 | 389 | if __name__ == "__main__": |
259 | 390 | main() |
0 commit comments