|
13 | 13 | import isort |
14 | 14 |
|
15 | 15 | from codeflash.cli_cmds.console import console, logger |
| 16 | +from codeflash.languages.registry import get_language_support |
16 | 17 | from codeflash.lsp.helpers import is_LSP_enabled |
17 | 18 |
|
18 | 19 |
|
@@ -47,8 +48,9 @@ def apply_formatter_cmds( |
47 | 48 | raise FileNotFoundError(msg) |
48 | 49 |
|
49 | 50 | file_path = path |
| 51 | + lang_support = get_language_support(path) |
50 | 52 | if test_dir_str: |
51 | | - file_path = Path(test_dir_str) / "temp.py" |
| 53 | + file_path = Path(test_dir_str) / ("temp" + lang_support.default_file_extension) |
52 | 54 | shutil.copy2(path, file_path) |
53 | 55 |
|
54 | 56 | file_token = "$file" # noqa: S105 |
@@ -87,13 +89,14 @@ def is_diff_line(line: str) -> bool: |
87 | 89 | return len(diff_lines) |
88 | 90 |
|
89 | 91 |
|
90 | | -def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str: |
| 92 | +def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str: |
91 | 93 | formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" |
92 | 94 | if formatter_name == "disabled": # nothing to do if no formatter provided |
93 | 95 | return re.sub(r"\n{2,}", "\n\n", generated_test_source) |
94 | 96 | with tempfile.TemporaryDirectory() as test_dir_str: |
95 | 97 | # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines |
96 | | - original_temp = Path(test_dir_str) / "original_temp.py" |
| 98 | + lang_support = get_language_support(language) |
| 99 | + original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension) |
97 | 100 | original_temp.write_text(generated_test_source, encoding="utf8") |
98 | 101 | _, formatted_code, changed = apply_formatter_cmds( |
99 | 102 | formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False |
@@ -130,7 +133,8 @@ def format_code( |
130 | 133 | # we don't count the formatting diff for the optimized function as it should be well-formatted |
131 | 134 | original_code_without_opfunc = original_code.replace(optimized_code, "") |
132 | 135 |
|
133 | | - original_temp = Path(test_dir_str) / "original_temp.py" |
| 136 | + lang_support = get_language_support(path) |
| 137 | + original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension) |
134 | 138 | original_temp.write_text(original_code_without_opfunc, encoding="utf8") |
135 | 139 |
|
136 | 140 | formatted_temp, formatted_code, changed = apply_formatter_cmds( |
@@ -160,6 +164,7 @@ def format_code( |
160 | 164 | _, formatted_code, changed = apply_formatter_cmds( |
161 | 165 | formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure |
162 | 166 | ) |
| 167 | + |
163 | 168 | if not changed: |
164 | 169 | logger.warning( |
165 | 170 | f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?" |
|
0 commit comments