Skip to content

Commit cd9f0f4

Browse files
tests: edge cases for cst formatting
1 parent 98454d0 commit cd9f0f4

4 files changed

Lines changed: 277 additions & 102 deletions

File tree

codeflash/discovery/functions_to_optimize.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,27 +80,29 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
8080
ending_line=pos.end.line,
8181
)
8282
)
83-
8483
class CodeRangeFunctionVisitor(cst.CSTVisitor):
85-
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.QualifiedNameProvider)
86-
84+
METADATA_DEPENDENCIES = (
85+
cst.metadata.PositionProvider,
86+
cst.metadata.QualifiedNameProvider,
87+
)
88+
8789
def __init__(self, target_function_name: str) -> None:
8890
super().__init__()
8991
self.target_func = target_function_name
90-
self.current_path = []
9192
self.start_line: Optional[int] = None
9293
self.end_line: Optional[int] = None
9394

9495
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
95-
qualified_names = {
96-
str(qn.name) for qn in
96+
qualified_names = [
97+
str(qn.name).replace(".<locals>", "") for qn in
9798
self.get_metadata(cst.metadata.QualifiedNameProvider, node)
98-
}
99-
99+
]
100100
if self.target_func in qualified_names:
101-
position = self.get_metadata(cst.metadata.PositionProvider, node)
102-
self.start_line = position.start.line
103-
self.end_line = position.end.line
101+
func_position = self.get_metadata(cst.metadata.PositionProvider, node)
102+
decorators_count = len(node.decorators)
103+
self.start_line = func_position.start.line - decorators_count
104+
self.end_line = func_position.end.line
105+
return False
104106

105107

106108
class FunctionWithReturnStatement(ast.NodeVisitor):

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
302302
code_context=code_context, optimized_code=best_optimization.candidate.source_code
303303
)
304304

305-
print("file_path_to_helper_classes\n", file_path_to_helper_classes)
306-
filepaths_to_inspect = [
307-
self.function_to_optimize.file_path,
308-
*list({helper.file_path for helper in code_context.helper_functions}),
309-
]
310-
print("filepaths_to_inspect\n", filepaths_to_inspect)
311-
312305
new_code, new_helper_code = self.reformat_code_and_helpers(
313306
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code,
314307
opt_func_name=explanation.function_name
@@ -610,13 +603,20 @@ def reformat_code_and_helpers(
610603
should_sort_imports = False
611604

612605
whole_file_content = path.read_text(encoding="utf8")
613-
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(whole_file_content))
606+
wrapper: cst.metadata.MetadataWrapper | None = None
607+
try:
608+
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(whole_file_content))
609+
except cst.ParserSyntaxError as e:
610+
logger.error(f"Syntax error detected, aborting reformatting.")
611+
return original_code, {}
612+
614613
visitor = CodeRangeFunctionVisitor(target_function_name=opt_func_name)
615614
wrapper.visit(visitor)
616615

617616
lines = whole_file_content.splitlines(keepends=True)
618617
if visitor.start_line == None:
619618
logger.error(f"Could not find function {opt_func_name} in {path}, aborting reformatting.")
619+
return original_code, {}
620620
else:
621621
opt_func_source_lines = lines[visitor.start_line-1:visitor.end_line]
622622

tests/test_formatter.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
import os
23
import tempfile
34
from pathlib import Path
@@ -7,6 +8,9 @@
78
from codeflash.code_utils.config_parser import parse_config_file
89
from codeflash.code_utils.formatter import format_code, sort_imports
910

11+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
12+
from codeflash.optimization.function_optimizer import FunctionOptimizer
13+
from codeflash.verification.verification_utils import TestConfig
1014

1115
def test_remove_duplicate_imports():
1216
"""Test that duplicate imports are removed when should_sort_imports is True."""
@@ -209,3 +213,255 @@ def foo():
209213
tmp_path = tmp.name
210214
with pytest.raises(FileNotFoundError):
211215
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))
216+
217+
############################################################
218+
################ CST based formatting tests ################
219+
############################################################
220+
@pytest.fixture
221+
def setup_cst_formatter_args():
222+
"""Common setup for reformat_code_and_helpers tests."""
223+
def _setup(unformatted_code, function_name):
224+
test_dir = Path(tempfile.mkdtemp())
225+
target_path = test_dir / "target.py"
226+
target_path.write_text(unformatted_code, encoding="utf-8")
227+
228+
function_to_optimize = FunctionToOptimize(
229+
function_name=function_name,
230+
parents=[],
231+
file_path=target_path
232+
)
233+
234+
test_cfg = TestConfig(
235+
tests_root=test_dir,
236+
project_root_path=test_dir,
237+
test_framework="pytest",
238+
tests_project_rootdir=test_dir,
239+
)
240+
241+
args = argparse.Namespace(
242+
disable_imports_sorting=False,
243+
formatter_cmds=[
244+
"ruff check --exit-zero --fix $file",
245+
"ruff format $file"
246+
],
247+
)
248+
249+
optimizer = FunctionOptimizer(
250+
function_to_optimize=function_to_optimize,
251+
test_cfg=test_cfg,
252+
args=args,
253+
)
254+
255+
return optimizer, target_path, function_to_optimize
256+
257+
yield _setup
258+
259+
260+
def test_reformat_code_and_helpers(setup_cst_formatter_args):
261+
"""
262+
reformat_code_and_helpers should format only the optimized function, not the whole file, to avoid producing large diffs.
263+
"""
264+
unformatted_code = """import sys
265+
266+
267+
def lol():
268+
print( "lol" )
269+
270+
271+
272+
273+
class MyClass:
274+
def __init__(self, x=0):
275+
self.x = x
276+
277+
def lol(self):
278+
print( "lol" )
279+
280+
def lol2 (self):
281+
print( " lol2" )"""
282+
283+
expected_code = """import sys
284+
285+
286+
def lol():
287+
print( "lol" )
288+
289+
290+
291+
292+
class MyClass:
293+
def __init__(self, x=0):
294+
self.x = x
295+
296+
def lol(self):
297+
print( "lol" )
298+
299+
def lol2(self):
300+
print(" lol2")
301+
"""
302+
303+
optimizer, target_path, function_to_optimize = setup_cst_formatter_args(
304+
unformatted_code, "MyClass.lol2"
305+
)
306+
307+
formatted_code, _ = optimizer.reformat_code_and_helpers(
308+
helper_functions=[],
309+
path=target_path,
310+
original_code=optimizer.function_to_optimize_source_code,
311+
opt_func_name=function_to_optimize.function_name
312+
)
313+
314+
assert formatted_code == expected_code
315+
316+
317+
def test_reformat_code_and_helpers_with_duplicated_target_function_names(setup_cst_formatter_args):
318+
unformatted_code = """import sys
319+
def lol():
320+
print( "lol" )
321+
322+
class MyClass:
323+
def __init__(self, x=0):
324+
self.x = x
325+
326+
def lol(self):
327+
print( "lol" )"""
328+
329+
expected_code = """import sys
330+
def lol():
331+
print( "lol" )
332+
333+
class MyClass:
334+
def __init__(self, x=0):
335+
self.x = x
336+
337+
def lol(self):
338+
print("lol")
339+
"""
340+
341+
optimizer, target_path, function_to_optimize = setup_cst_formatter_args(
342+
unformatted_code, "MyClass.lol"
343+
)
344+
345+
formatted_code, _ = optimizer.reformat_code_and_helpers(
346+
helper_functions=[],
347+
path=target_path,
348+
original_code=optimizer.function_to_optimize_source_code,
349+
opt_func_name=function_to_optimize.function_name
350+
)
351+
352+
assert formatted_code == expected_code
353+
354+
355+
356+
def test_formatting_nested_functions(setup_cst_formatter_args):
357+
unformatted_code = """def hello():
358+
print("Hello")
359+
def nested_function() :
360+
print ("This is a nested function")
361+
def another_nested_function():
362+
print ("This is another nested function")"""
363+
364+
expected_code = """def hello():
365+
print("Hello")
366+
def nested_function():
367+
print("This is a nested function")
368+
def another_nested_function():
369+
print ("This is another nested function")"""
370+
371+
optimizer, target_path, function_to_optimize = setup_cst_formatter_args(
372+
unformatted_code, "hello.nested_function"
373+
)
374+
375+
formatted_code, _ = optimizer.reformat_code_and_helpers(
376+
helper_functions=[],
377+
path=target_path,
378+
original_code=optimizer.function_to_optimize_source_code,
379+
opt_func_name=function_to_optimize.function_name
380+
)
381+
382+
assert formatted_code == expected_code
383+
384+
385+
def test_formatting_standalone_functions(setup_cst_formatter_args):
386+
unformatted_code = """def func1 ():
387+
print( "This is a function with bad formatting")
388+
def func2() :
389+
print ( "This is another function with bad formatting" )
390+
"""
391+
392+
expected_code = """def func1 ():
393+
print( "This is a function with bad formatting")
394+
def func2():
395+
print("This is another function with bad formatting")
396+
"""
397+
398+
optimizer, target_path, function_to_optimize = setup_cst_formatter_args(
399+
unformatted_code, "func2"
400+
)
401+
402+
formatted_code, _ = optimizer.reformat_code_and_helpers(
403+
helper_functions=[],
404+
path=target_path,
405+
original_code=optimizer.function_to_optimize_source_code,
406+
opt_func_name=function_to_optimize.function_name
407+
)
408+
409+
assert formatted_code == expected_code
410+
411+
412+
def test_formatting_function_with_decorators(setup_cst_formatter_args):
413+
unformatted_code = """@decorator1
414+
@decorator2( arg1 , arg2 )
415+
def func1 ():
416+
print( "This is a function with bad formatting")
417+
418+
@another_decorator( arg)
419+
def func2 ( x,y ):
420+
print ( "This is another function with bad formatting" )"""
421+
422+
expected_code = """@decorator1
423+
@decorator2( arg1 , arg2 )
424+
def func1 ():
425+
print( "This is a function with bad formatting")
426+
427+
@another_decorator(arg)
428+
def func2(x, y):
429+
print("This is another function with bad formatting")
430+
"""
431+
432+
optimizer, target_path, function_to_optimize = setup_cst_formatter_args(
433+
unformatted_code, "func2"
434+
)
435+
436+
formatted_code, _ = optimizer.reformat_code_and_helpers(
437+
helper_functions=[],
438+
path=target_path,
439+
original_code=optimizer.function_to_optimize_source_code,
440+
opt_func_name=function_to_optimize.function_name
441+
)
442+
443+
assert formatted_code == expected_code
444+
445+
446+
def test_formatting_function_with_syntax_error(setup_cst_formatter_args):
447+
"""shouldn't happen anyway, but just in case"""
448+
unformatted_code = """def func1():
449+
print("This is a function with a syntax error"
450+
def func2():
451+
print("This is another function with a syntax error")
452+
"""
453+
454+
expected_code = unformatted_code # No formatting should be applied due to syntax error
455+
456+
optimizer, target_path, function_to_optimize = setup_cst_formatter_args(
457+
unformatted_code, "func2"
458+
)
459+
460+
formatted_code, _ = optimizer.reformat_code_and_helpers(
461+
helper_functions=[],
462+
path=target_path,
463+
original_code=optimizer.function_to_optimize_source_code,
464+
opt_func_name=function_to_optimize.function_name
465+
)
466+
467+
assert formatted_code == expected_code

0 commit comments

Comments
 (0)