Skip to content

Commit afc4941

Browse files
Merge pull request #1230 from codeflash-ai/fix/get-language-based-on-formatter
[FIX] Determine language based on common formatters
2 parents 4803f26 + d0ec9b3 commit afc4941

6 files changed

Lines changed: 77 additions & 7 deletions

File tree

codeflash/code_utils/env_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
from codeflash.code_utils.code_utils import exit_with_message
1414
from codeflash.code_utils.formatter import format_code
1515
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
16+
from codeflash.languages.base import Language
17+
from codeflash.languages.registry import get_language_support_by_common_formatters
1618
from codeflash.lsp.helpers import is_LSP_enabled
1719

1820

19-
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
21+
def check_formatter_installed(
22+
formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python"
23+
) -> bool:
2024
if not formatter_cmds or formatter_cmds[0] == "disabled":
2125
return True
2226
first_cmd = formatter_cmds[0]
@@ -35,10 +39,21 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
3539
)
3640
return False
3741

38-
tmp_code = """print("hello world")"""
42+
lang_support = get_language_support_by_common_formatters(formatter_cmds)
43+
if not lang_support:
44+
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")
45+
return True
46+
47+
if lang_support.language == Language.PYTHON:
48+
tmp_code = """print("hello world")"""
49+
elif lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
50+
tmp_code = "console.log('hello world');"
51+
else:
52+
return True
53+
3954
try:
4055
with tempfile.TemporaryDirectory() as tmpdir:
41-
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
56+
tmp_file = Path(tmpdir) / ("test_codeflash_formatter" + lang_support.default_file_extension)
4257
tmp_file.write_text(tmp_code, encoding="utf-8")
4358
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=False)
4459
return True

codeflash/code_utils/formatter.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import isort
1414

1515
from codeflash.cli_cmds.console import console, logger
16+
from codeflash.languages.registry import get_language_support
1617
from codeflash.lsp.helpers import is_LSP_enabled
1718

1819

@@ -47,8 +48,9 @@ def apply_formatter_cmds(
4748
raise FileNotFoundError(msg)
4849

4950
file_path = path
51+
lang_support = get_language_support(path)
5052
if test_dir_str:
51-
file_path = Path(test_dir_str) / "temp.py"
53+
file_path = Path(test_dir_str) / ("temp" + lang_support.default_file_extension)
5254
shutil.copy2(path, file_path)
5355

5456
file_token = "$file" # noqa: S105
@@ -87,13 +89,14 @@ def is_diff_line(line: str) -> bool:
8789
return len(diff_lines)
8890

8991

90-
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
92+
def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str:
9193
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
9294
if formatter_name == "disabled": # nothing to do if no formatter provided
9395
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
9496
with tempfile.TemporaryDirectory() as test_dir_str:
9597
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
96-
original_temp = Path(test_dir_str) / "original_temp.py"
98+
lang_support = get_language_support(language)
99+
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
97100
original_temp.write_text(generated_test_source, encoding="utf8")
98101
_, formatted_code, changed = apply_formatter_cmds(
99102
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
@@ -130,7 +133,8 @@ def format_code(
130133
# we don't count the formatting diff for the optimized function as it should be well-formatted
131134
original_code_without_opfunc = original_code.replace(optimized_code, "")
132135

133-
original_temp = Path(test_dir_str) / "original_temp.py"
136+
lang_support = get_language_support(path)
137+
original_temp = Path(test_dir_str) / ("original_temp" + lang_support.default_file_extension)
134138
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
135139

136140
formatted_temp, formatted_code, changed = apply_formatter_cmds(
@@ -160,6 +164,7 @@ def format_code(
160164
_, formatted_code, changed = apply_formatter_cmds(
161165
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
162166
)
167+
163168
if not changed:
164169
logger.warning(
165170
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"

codeflash/languages/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,11 @@ def file_extensions(self) -> tuple[str, ...]:
278278
"""
279279
...
280280

281+
@property
282+
def default_file_extension(self) -> str:
283+
"""Default file extension for this language."""
284+
...
285+
281286
@property
282287
def test_framework(self) -> str:
283288
"""Primary test framework name.

codeflash/languages/javascript/support.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ def file_extensions(self) -> tuple[str, ...]:
5353
"""File extensions supported by JavaScript."""
5454
return (".js", ".jsx", ".mjs", ".cjs")
5555

56+
@property
57+
def default_file_extension(self) -> str:
58+
"""Default file extension for JavaScript."""
59+
return ".js"
60+
5661
@property
5762
def test_framework(self) -> str:
5863
"""Primary test framework for JavaScript."""

codeflash/languages/python/support.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def file_extensions(self) -> tuple[str, ...]:
4545
"""File extensions supported by Python."""
4646
return (".py", ".pyw")
4747

48+
@property
49+
def default_file_extension(self) -> str:
50+
"""Default file extension for Python."""
51+
return ".py"
52+
4853
@property
4954
def test_framework(self) -> str:
5055
"""Primary test framework for Python."""

codeflash/languages/registry.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,41 @@ def get_language_support(identifier: Path | Language | str) -> LanguageSupport:
178178
_FRAMEWORK_CACHE: dict[str, LanguageSupport] = {}
179179

180180

181+
def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> LanguageSupport | None:
182+
language: Language | None = None
183+
if isinstance(formatter_cmd, str):
184+
formatter_cmd = [formatter_cmd]
185+
186+
if len(formatter_cmd) == 1:
187+
formatter_cmd = formatter_cmd[0].split(" ")
188+
189+
# Try as extension first
190+
ext = None
191+
192+
py_formatters = ["black", "isort", "ruff", "autopep8", "yapf", "pyfmt"]
193+
js_ts_formatters = ["prettier", "eslint", "biome", "rome", "deno", "standard", "tslint"]
194+
195+
if any(cmd in py_formatters for cmd in formatter_cmd):
196+
ext = ".py"
197+
elif any(cmd in js_ts_formatters for cmd in formatter_cmd):
198+
ext = ".js"
199+
200+
if ext is None:
201+
# can't determine language
202+
return None
203+
204+
cls = _EXTENSION_REGISTRY[ext]
205+
language = cls().language
206+
207+
# Return cached instance or create new one
208+
if language not in _SUPPORT_CACHE:
209+
if language not in _LANGUAGE_REGISTRY:
210+
raise UnsupportedLanguageError(str(language), get_supported_languages())
211+
_SUPPORT_CACHE[language] = _LANGUAGE_REGISTRY[language]()
212+
213+
return _SUPPORT_CACHE[language]
214+
215+
181216
def get_language_support_by_framework(test_framework: str) -> LanguageSupport | None:
182217
"""Get language support for a test framework.
183218

0 commit comments

Comments
 (0)