|
| 1 | +import argparse |
1 | 2 | import os |
2 | 3 | import tempfile |
3 | 4 | from pathlib import Path |
|
7 | 8 | from codeflash.code_utils.config_parser import parse_config_file |
8 | 9 | from codeflash.code_utils.formatter import format_code, sort_imports |
9 | 10 |
|
| 11 | +from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
| 12 | +from codeflash.optimization.function_optimizer import FunctionOptimizer |
| 13 | +from codeflash.verification.verification_utils import TestConfig |
10 | 14 |
|
11 | 15 | def test_remove_duplicate_imports(): |
12 | 16 | """Test that duplicate imports are removed when should_sort_imports is True.""" |
@@ -209,3 +213,255 @@ def foo(): |
209 | 213 | tmp_path = tmp.name |
210 | 214 | with pytest.raises(FileNotFoundError): |
211 | 215 | format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) |
| 216 | + |
| 217 | +############################################################ |
| 218 | +################ CST based formatting tests ################ |
| 219 | +############################################################ |
| 220 | +@pytest.fixture |
| 221 | +def setup_cst_formatter_args(): |
| 222 | + """Common setup for reformat_code_and_helpers tests.""" |
| 223 | + def _setup(unformatted_code, function_name): |
| 224 | + test_dir = Path(tempfile.mkdtemp()) |
| 225 | + target_path = test_dir / "target.py" |
| 226 | + target_path.write_text(unformatted_code, encoding="utf-8") |
| 227 | + |
| 228 | + function_to_optimize = FunctionToOptimize( |
| 229 | + function_name=function_name, |
| 230 | + parents=[], |
| 231 | + file_path=target_path |
| 232 | + ) |
| 233 | + |
| 234 | + test_cfg = TestConfig( |
| 235 | + tests_root=test_dir, |
| 236 | + project_root_path=test_dir, |
| 237 | + test_framework="pytest", |
| 238 | + tests_project_rootdir=test_dir, |
| 239 | + ) |
| 240 | + |
| 241 | + args = argparse.Namespace( |
| 242 | + disable_imports_sorting=False, |
| 243 | + formatter_cmds=[ |
| 244 | + "ruff check --exit-zero --fix $file", |
| 245 | + "ruff format $file" |
| 246 | + ], |
| 247 | + ) |
| 248 | + |
| 249 | + optimizer = FunctionOptimizer( |
| 250 | + function_to_optimize=function_to_optimize, |
| 251 | + test_cfg=test_cfg, |
| 252 | + args=args, |
| 253 | + ) |
| 254 | + |
| 255 | + return optimizer, target_path, function_to_optimize |
| 256 | + |
| 257 | + yield _setup |
| 258 | + |
| 259 | + |
| 260 | +def test_reformat_code_and_helpers(setup_cst_formatter_args): |
| 261 | + """ |
| 262 | + reformat_code_and_helpers should only format the code that is optimized not the whole file, to avoid large diffing |
| 263 | + """ |
| 264 | + unformatted_code = """import sys |
| 265 | +
|
| 266 | +
|
| 267 | +def lol(): |
| 268 | + print( "lol" ) |
| 269 | +
|
| 270 | +
|
| 271 | +
|
| 272 | +
|
| 273 | +class MyClass: |
| 274 | + def __init__(self, x=0): |
| 275 | + self.x = x |
| 276 | +
|
| 277 | + def lol(self): |
| 278 | + print( "lol" ) |
| 279 | +
|
| 280 | + def lol2 (self): |
| 281 | + print( " lol2" )""" |
| 282 | + |
| 283 | + expected_code = """import sys |
| 284 | +
|
| 285 | +
|
| 286 | +def lol(): |
| 287 | + print( "lol" ) |
| 288 | +
|
| 289 | +
|
| 290 | +
|
| 291 | +
|
| 292 | +class MyClass: |
| 293 | + def __init__(self, x=0): |
| 294 | + self.x = x |
| 295 | +
|
| 296 | + def lol(self): |
| 297 | + print( "lol" ) |
| 298 | +
|
| 299 | + def lol2(self): |
| 300 | + print(" lol2") |
| 301 | +""" |
| 302 | + |
| 303 | + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( |
| 304 | + unformatted_code, "MyClass.lol2" |
| 305 | + ) |
| 306 | + |
| 307 | + formatted_code, _ = optimizer.reformat_code_and_helpers( |
| 308 | + helper_functions=[], |
| 309 | + path=target_path, |
| 310 | + original_code=optimizer.function_to_optimize_source_code, |
| 311 | + opt_func_name=function_to_optimize.function_name |
| 312 | + ) |
| 313 | + |
| 314 | + assert formatted_code == expected_code |
| 315 | + |
| 316 | + |
| 317 | +def test_reformat_code_and_helpers_with_duplicated_target_function_names(setup_cst_formatter_args): |
| 318 | + unformatted_code = """import sys |
| 319 | +def lol(): |
| 320 | + print( "lol" ) |
| 321 | +
|
| 322 | +class MyClass: |
| 323 | + def __init__(self, x=0): |
| 324 | + self.x = x |
| 325 | +
|
| 326 | + def lol(self): |
| 327 | + print( "lol" )""" |
| 328 | + |
| 329 | + expected_code = """import sys |
| 330 | +def lol(): |
| 331 | + print( "lol" ) |
| 332 | +
|
| 333 | +class MyClass: |
| 334 | + def __init__(self, x=0): |
| 335 | + self.x = x |
| 336 | +
|
| 337 | + def lol(self): |
| 338 | + print("lol") |
| 339 | +""" |
| 340 | + |
| 341 | + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( |
| 342 | + unformatted_code, "MyClass.lol" |
| 343 | + ) |
| 344 | + |
| 345 | + formatted_code, _ = optimizer.reformat_code_and_helpers( |
| 346 | + helper_functions=[], |
| 347 | + path=target_path, |
| 348 | + original_code=optimizer.function_to_optimize_source_code, |
| 349 | + opt_func_name=function_to_optimize.function_name |
| 350 | + ) |
| 351 | + |
| 352 | + assert formatted_code == expected_code |
| 353 | + |
| 354 | + |
| 355 | + |
| 356 | +def test_formatting_nested_functions(setup_cst_formatter_args): |
| 357 | + unformatted_code = """def hello(): |
| 358 | + print("Hello") |
| 359 | + def nested_function() : |
| 360 | + print ("This is a nested function") |
| 361 | + def another_nested_function(): |
| 362 | + print ("This is another nested function")""" |
| 363 | + |
| 364 | + expected_code = """def hello(): |
| 365 | + print("Hello") |
| 366 | + def nested_function(): |
| 367 | + print("This is a nested function") |
| 368 | + def another_nested_function(): |
| 369 | + print ("This is another nested function")""" |
| 370 | + |
| 371 | + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( |
| 372 | + unformatted_code, "hello.nested_function" |
| 373 | + ) |
| 374 | + |
| 375 | + formatted_code, _ = optimizer.reformat_code_and_helpers( |
| 376 | + helper_functions=[], |
| 377 | + path=target_path, |
| 378 | + original_code=optimizer.function_to_optimize_source_code, |
| 379 | + opt_func_name=function_to_optimize.function_name |
| 380 | + ) |
| 381 | + |
| 382 | + assert formatted_code == expected_code |
| 383 | + |
| 384 | + |
| 385 | +def test_formatting_standalone_functions(setup_cst_formatter_args): |
| 386 | + unformatted_code = """def func1 (): |
| 387 | + print( "This is a function with bad formatting") |
| 388 | +def func2() : |
| 389 | + print ( "This is another function with bad formatting" ) |
| 390 | +""" |
| 391 | + |
| 392 | + expected_code = """def func1 (): |
| 393 | + print( "This is a function with bad formatting") |
| 394 | +def func2(): |
| 395 | + print("This is another function with bad formatting") |
| 396 | +""" |
| 397 | + |
| 398 | + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( |
| 399 | + unformatted_code, "func2" |
| 400 | + ) |
| 401 | + |
| 402 | + formatted_code, _ = optimizer.reformat_code_and_helpers( |
| 403 | + helper_functions=[], |
| 404 | + path=target_path, |
| 405 | + original_code=optimizer.function_to_optimize_source_code, |
| 406 | + opt_func_name=function_to_optimize.function_name |
| 407 | + ) |
| 408 | + |
| 409 | + assert formatted_code == expected_code |
| 410 | + |
| 411 | + |
| 412 | +def test_formatting_function_with_decorators(setup_cst_formatter_args): |
| 413 | + unformatted_code = """@decorator1 |
| 414 | +@decorator2( arg1 , arg2 ) |
| 415 | +def func1 (): |
| 416 | + print( "This is a function with bad formatting") |
| 417 | +
|
| 418 | +@another_decorator( arg) |
| 419 | +def func2 ( x,y ): |
| 420 | + print ( "This is another function with bad formatting" )""" |
| 421 | + |
| 422 | + expected_code = """@decorator1 |
| 423 | +@decorator2( arg1 , arg2 ) |
| 424 | +def func1 (): |
| 425 | + print( "This is a function with bad formatting") |
| 426 | +
|
| 427 | +@another_decorator(arg) |
| 428 | +def func2(x, y): |
| 429 | + print("This is another function with bad formatting") |
| 430 | +""" |
| 431 | + |
| 432 | + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( |
| 433 | + unformatted_code, "func2" |
| 434 | + ) |
| 435 | + |
| 436 | + formatted_code, _ = optimizer.reformat_code_and_helpers( |
| 437 | + helper_functions=[], |
| 438 | + path=target_path, |
| 439 | + original_code=optimizer.function_to_optimize_source_code, |
| 440 | + opt_func_name=function_to_optimize.function_name |
| 441 | + ) |
| 442 | + |
| 443 | + assert formatted_code == expected_code |
| 444 | + |
| 445 | + |
| 446 | +def test_formatting_function_with_syntax_error(setup_cst_formatter_args): |
| 447 | + """shouldn't happen anyway, but just in case""" |
| 448 | + unformatted_code = """def func1(): |
| 449 | + print("This is a function with a syntax error" |
| 450 | +def func2(): |
| 451 | + print("This is another function with a syntax error") |
| 452 | +""" |
| 453 | + |
| 454 | + expected_code = unformatted_code # No formatting should be applied due to syntax error |
| 455 | + |
| 456 | + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( |
| 457 | + unformatted_code, "func2" |
| 458 | + ) |
| 459 | + |
| 460 | + formatted_code, _ = optimizer.reformat_code_and_helpers( |
| 461 | + helper_functions=[], |
| 462 | + path=target_path, |
| 463 | + original_code=optimizer.function_to_optimize_source_code, |
| 464 | + opt_func_name=function_to_optimize.function_name |
| 465 | + ) |
| 466 | + |
| 467 | + assert formatted_code == expected_code |
0 commit comments