Skip to content

Commit 76d2e12

Browse files
authored
Merge branch 'main' into fix/cf_setup
2 parents 7f7cc62 + 2cee81b commit 76d2e12

31 files changed

Lines changed: 1127 additions & 79 deletions

CLAUDE.md

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,11 @@ codeflash/
5757
└── result/ # Result types and handling
5858
```
5959

60-
### Key Patterns
61-
62-
**Either/Result pattern for errors:**
63-
```python
64-
from codeflash.either import is_successful, Success, Failure
65-
result = some_operation()
66-
if is_successful(result):
67-
value = result.unwrap()
68-
else:
69-
error = result.failure()
70-
```
60+
### Key Rules to follow
7161

72-
**Use libcst, not ast** - Always use `libcst` for code parsing/modification to preserve formatting.
62+
- Use libcst, not ast - For Python, always use `libcst` for code parsing/modification to preserve formatting.
63+
- Code context extraction and replacement tests must always assert for full string equality, no substring matching.
64+
- Any new feature or bug fix that can be tested automatically must have test cases.
7365

7466
## Code Style
7567

codeflash/benchmarking/plugin/plugin.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
200200

201201
# Pytest hooks
202202
@pytest.hookimpl
203-
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001
203+
def pytest_sessionfinish(self, session, exitstatus) -> None:
204204
"""Execute after whole test run is completed."""
205205
# Write any remaining benchmark timings to the database
206206
codeflash_trace.close()
@@ -236,20 +236,20 @@ class Benchmark: # noqa: D106
236236
def __init__(self, request: pytest.FixtureRequest) -> None:
237237
self.request = request
238238

239-
def __call__(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN204
239+
def __call__(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
240240
"""Handle both direct function calls and decorator usage."""
241241
if args or kwargs:
242242
# Used as benchmark(func, *args, **kwargs)
243243
return self._run_benchmark(func, *args, **kwargs)
244244

245245
# Used as @benchmark decorator
246-
def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
246+
def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003
247247
return func(*args, **kwargs)
248248

249249
self._run_benchmark(func)
250250
return wrapped_func
251251

252-
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN202
252+
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003
253253
"""Actual benchmark implementation."""
254254
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
255255
if node_path is None:

codeflash/cli_cmds/cli.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,22 @@ def process_pyproject_config(args: Namespace) -> Namespace:
231231
is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript")
232232
if args.tests_root is None:
233233
if is_js_ts_project:
234-
# Try common JS test directories, or default to module_root
234+
# Try common JS test directories at project root first
235235
for test_dir in ["test", "tests", "__tests__"]:
236236
if Path(test_dir).is_dir():
237237
args.tests_root = test_dir
238238
break
239+
# If not found at project root, try inside module_root (e.g., src/test, src/__tests__)
240+
if args.tests_root is None and args.module_root:
241+
module_root_path = Path(args.module_root)
242+
for test_dir in ["test", "tests", "__tests__"]:
243+
test_path = module_root_path / test_dir
244+
if test_path.is_dir():
245+
args.tests_root = str(test_path)
246+
break
247+
# Final fallback: default to module_root
248+
# Note: This may cause issues if tests are colocated with source files
249+
# In such cases, the user should explicitly configure testsRoot in package.json
239250
if args.tests_root is None:
240251
args.tests_root = args.module_root
241252
else:

codeflash/cli_cmds/console.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def code_print(
129129

130130
spinners = cycle(SPINNER_TYPES)
131131

132+
# Track whether a progress bar is already active to prevent nested Live displays
133+
_progress_bar_active = False
134+
132135

133136
@contextmanager
134137
def progress_bar(
@@ -138,28 +141,38 @@ def progress_bar(
138141
139142
If revert_to_print is True, falls back to printing a single logger.info message
140143
instead of showing a progress bar.
144+
145+
If a progress bar is already active, yields a dummy task ID to avoid Rich's
146+
LiveError from nested Live displays.
141147
"""
148+
global _progress_bar_active
149+
142150
if is_LSP_enabled():
143151
lsp_log(LspTextMessage(text=message, takes_time=True))
144152
yield
145153
return
146154

147-
if revert_to_print:
148-
logger.info(message)
155+
if revert_to_print or _progress_bar_active:
156+
if revert_to_print:
157+
logger.info(message)
149158

150159
# Create a fake task ID since we still need to yield something
151160
yield DummyTask().id
152161
else:
153-
progress = Progress(
154-
SpinnerColumn(next(spinners)),
155-
*Progress.get_default_columns(),
156-
TimeElapsedColumn(),
157-
console=console,
158-
transient=transient,
159-
)
160-
task = progress.add_task(message, total=None)
161-
with progress:
162-
yield task
162+
_progress_bar_active = True
163+
try:
164+
progress = Progress(
165+
SpinnerColumn(next(spinners)),
166+
*Progress.get_default_columns(),
167+
TimeElapsedColumn(),
168+
console=console,
169+
transient=transient,
170+
)
171+
task = progress.add_task(message, total=None)
172+
with progress:
173+
yield task
174+
finally:
175+
_progress_bar_active = False
163176

164177

165178
@contextmanager

codeflash/code_utils/code_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def extract_unique_errors(pytest_output: str) -> set[str]:
436436
pattern = r"^E\s+(.*)$"
437437

438438
for error_message in re.findall(pattern, pytest_output, re.MULTILINE):
439-
error_message = error_message.strip() # noqa: PLW2901
439+
error_message = error_message.strip()
440440
if error_message:
441441
unique_errors.add(error_message)
442442

codeflash/code_utils/config_js.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
105105
return "."
106106

107107

108-
def detect_test_runner(project_root: Path, package_data: dict[str, Any]) -> str: # noqa: ARG001
108+
def detect_test_runner(project_root: Path, package_data: dict[str, Any]) -> str:
109109
"""Detect test runner from devDependencies or scripts.test.
110110
111111
Detection order:
@@ -144,7 +144,7 @@ def detect_test_runner(project_root: Path, package_data: dict[str, Any]) -> str:
144144
return "jest"
145145

146146

147-
def detect_formatter(project_root: Path, package_data: dict[str, Any]) -> list[str] | None: # noqa: ARG001
147+
def detect_formatter(project_root: Path, package_data: dict[str, Any]) -> list[str] | None:
148148
"""Detect formatter from devDependencies.
149149
150150
Detection order:

codeflash/code_utils/deduplicate_code.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515

1616
def normalize_code(
17-
code: str,
18-
remove_docstrings: bool = True,
19-
return_ast_dump: bool = False,
20-
language: str | None = None,
17+
code: str, remove_docstrings: bool = True, return_ast_dump: bool = False, language: str | None = None
2118
) -> str:
2219
"""Normalize code by parsing, cleaning, and normalizing variable names.
2320

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def find_and_update_line_node(
8989
# it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
9090

9191
# Helper for manual walk
92-
def iter_ast_calls(node): # noqa: ANN202
92+
def iter_ast_calls(node):
9393
# Generator to yield each ast.Call in test_node, preserves node identity
9494
stack = [node]
9595
while stack:
@@ -102,7 +102,7 @@ def iter_ast_calls(node): # noqa: ANN202
102102
if isinstance(value, list):
103103
for item in reversed(value):
104104
if isinstance(item, ast.AST):
105-
stack.append(item) # noqa: PERF401
105+
stack.append(item)
106106
elif isinstance(value, ast.AST):
107107
stack.append(value)
108108

codeflash/context/code_context_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def build_testgen_context(
4646
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
4747
helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
4848
project_root_path: Path,
49-
remove_docstrings: bool, # noqa: FBT001
50-
include_imported_classes: bool, # noqa: FBT001
49+
remove_docstrings: bool,
50+
include_imported_classes: bool,
5151
) -> CodeStringsMarkdown:
5252
"""Build testgen context with optional imported class definitions and external base inits."""
5353
testgen_context = extract_code_markdown_context_from_files(

codeflash/discovery/functions_to_optimize.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,32 @@ def get_files_for_language(
216216
return files
217217

218218

219+
def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bool, str | None]:
220+
"""Check if a JavaScript/TypeScript function is exported from its module.
221+
222+
For JS/TS, functions that are not exported cannot be imported by tests,
223+
making them impossible to optimize.
224+
225+
Args:
226+
file_path: Path to the source file.
227+
function_name: Name of the function to check.
228+
229+
Returns:
230+
Tuple of (is_exported, export_name). export_name may be 'default' for default exports.
231+
232+
"""
233+
from codeflash.languages.treesitter_utils import get_analyzer_for_file
234+
235+
try:
236+
source = file_path.read_text(encoding="utf-8")
237+
analyzer = get_analyzer_for_file(file_path)
238+
return analyzer.is_function_exported(source, function_name)
239+
except Exception as e:
240+
logger.debug(f"Failed to check export status for {function_name}: {e}")
241+
# Return True to avoid blocking in case of errors
242+
return True, None
243+
244+
219245
def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
220246
"""Find all optimizable functions in a Python file using AST parsing.
221247
@@ -338,6 +364,36 @@ def get_functions_to_optimize(
338364
exit_with_message(
339365
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
340366
)
367+
368+
# For JavaScript/TypeScript, verify that the function (or its parent class) is exported
369+
# Non-exported functions cannot be imported by tests
370+
if found_function.language in ("javascript", "typescript"):
371+
# For class methods, check if the parent class is exported
372+
# For standalone functions, check if the function itself is exported
373+
if found_function.parents:
374+
# It's a class method - check if the class is exported
375+
name_to_check = found_function.top_level_parent_name
376+
else:
377+
# It's a standalone function - check if the function is exported
378+
name_to_check = found_function.function_name
379+
380+
is_exported, export_name = _is_js_ts_function_exported(file, name_to_check)
381+
if not is_exported:
382+
if found_function.parents:
383+
logger.debug(
384+
f"Class '{name_to_check}' containing method '{found_function.function_name}' "
385+
f"is not exported from {file}. "
386+
f"In JavaScript/TypeScript, only exported classes/functions can be optimized "
387+
f"because tests need to import them."
388+
)
389+
else:
390+
logger.debug(
391+
f"Function '{found_function.function_name}' is not exported from {file}. "
392+
f"In JavaScript/TypeScript, only exported functions can be optimized because "
393+
f"tests need to import them."
394+
)
395+
return {}, 0, None
396+
341397
functions[file] = [found_function]
342398
else:
343399
logger.info("Finding all functions modified in the current git diff ...")

0 commit comments

Comments
 (0)