Skip to content

Commit b5b55d4

Browse files
instrument standalone test calls and some other instrumentaion fixes
1 parent 6362aef commit b5b55d4

5 files changed

Lines changed: 341 additions & 80 deletions

File tree

codeflash/code_utils/edit_generated_tests.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _is_python_file(file_path: Path) -> bool:
154154
"""Check if a file is a Python file."""
155155
return file_path.suffix == ".py"
156156

157+
157158
# TODO:{self} Needs cleanup for jest logic in else block
158159
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
159160
unique_inv_ids: dict[str, int] = {}
@@ -174,9 +175,22 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_
174175
else:
175176
# Check for Jest test file extensions (e.g., tests.fibonacci.test.ts)
176177
# These need special handling to avoid converting .test.ts -> /test/ts
177-
jest_test_extensions = (".test.ts", ".test.js", ".test.tsx", ".test.jsx",
178-
".spec.ts", ".spec.js", ".spec.tsx", ".spec.jsx",
179-
".ts", ".js", ".tsx", ".jsx", ".mjs", ".mts")
178+
jest_test_extensions = (
179+
".test.ts",
180+
".test.js",
181+
".test.tsx",
182+
".test.jsx",
183+
".spec.ts",
184+
".spec.js",
185+
".spec.tsx",
186+
".spec.jsx",
187+
".ts",
188+
".js",
189+
".tsx",
190+
".jsx",
191+
".mjs",
192+
".mts",
193+
)
180194
matched_ext = None
181195
for ext in jest_test_extensions:
182196
if test_module_path.endswith(ext):
@@ -186,7 +200,7 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_
186200
if matched_ext:
187201
# JavaScript/TypeScript: convert module-style path to file path
188202
# "tests.fibonacci__perfonlyinstrumented.test.ts" -> "tests/fibonacci__perfonlyinstrumented.test.ts"
189-
base_path = test_module_path[:-len(matched_ext)]
203+
base_path = test_module_path[: -len(matched_ext)]
190204
file_path = base_path.replace(".", os.sep) + matched_ext
191205
# Check if the module path includes the tests directory name
192206
tests_dir_name = tests_project_rootdir.name
@@ -321,9 +335,7 @@ def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.P
321335
_CODEFLASH_REQUIRE_PATTERN = re.compile(
322336
r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
323337
)
324-
_CODEFLASH_IMPORT_PATTERN = re.compile(
325-
r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]"
326-
)
338+
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")
327339

328340

329341
def normalize_codeflash_imports(source: str) -> str:
@@ -344,18 +356,55 @@ def normalize_codeflash_imports(source: str) -> str:
344356
345357
"""
346358
# Replace CommonJS require
347-
source = _CODEFLASH_REQUIRE_PATTERN.sub(
348-
r"\1 \2 = require('codeflash')",
349-
source,
350-
)
359+
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
351360
# Replace ES module import
352-
source = _CODEFLASH_IMPORT_PATTERN.sub(
353-
r"import \1 from 'codeflash'",
354-
source,
355-
)
361+
source = _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
356362
return source
357363

358364

365+
def inject_test_globals(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
366+
# TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
367+
"""Inject test globals into all generated tests.
368+
369+
Args:
370+
generated_tests: List of generated tests.
371+
372+
Returns:
373+
Generated tests with test globals injected.
374+
375+
"""
376+
# we only inject test globals for esm modules
377+
global_import = (
378+
"import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
379+
)
380+
381+
for test in generated_tests.generated_tests:
382+
test.generated_original_test_source = global_import + test.generated_original_test_source
383+
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
384+
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
385+
return generated_tests
386+
387+
388+
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
389+
"""Disable TypeScript type checking in all generated tests.
390+
391+
Args:
392+
generated_tests: List of generated tests.
393+
394+
Returns:
395+
Generated tests with TypeScript type checking disabled.
396+
397+
"""
398+
# we only inject test globals for esm modules
399+
ts_nocheck = "// @ts-nocheck\n"
400+
401+
for test in generated_tests.generated_tests:
402+
test.generated_original_test_source = ts_nocheck + test.generated_original_test_source
403+
test.instrumented_behavior_test_source = ts_nocheck + test.instrumented_behavior_test_source
404+
test.instrumented_perf_test_source = ts_nocheck + test.instrumented_perf_test_source
405+
return generated_tests
406+
407+
359408
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
360409
"""Normalize codeflash imports in all generated tests.
361410

codeflash/languages/current.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@
3434
from codeflash.languages.base import LanguageSupport
3535

3636
# Module-level singleton for the current language
37-
_current_language: Language = Language.PYTHON
37+
_current_language: Language | None = None
3838

3939

4040
def current_language() -> Language:
4141
"""Get the current language being used in this codeflash session.
4242
4343
Returns:
4444
The current Language enum value.
45+
4546
"""
4647
return _current_language
4748

@@ -54,13 +55,13 @@ def set_current_language(language: Language | str) -> None:
5455
5556
Args:
5657
language: Either a Language enum value or a string like "python", "javascript", "typescript".
58+
5759
"""
5860
global _current_language
5961

60-
if isinstance(language, str):
61-
_current_language = Language(language)
62-
else:
63-
_current_language = language
62+
if _current_language is not None:
63+
return
64+
_current_language = Language(language) if isinstance(language, str) else language
6465

6566

6667
def reset_current_language() -> None:
@@ -77,6 +78,7 @@ def is_python() -> bool:
7778
7879
Returns:
7980
True if the current language is Python.
81+
8082
"""
8183
return _current_language == Language.PYTHON
8284

@@ -89,6 +91,7 @@ def is_javascript() -> bool:
8991
9092
Returns:
9193
True if the current language is JavaScript or TypeScript.
94+
9295
"""
9396
return _current_language in (Language.JAVASCRIPT, Language.TYPESCRIPT)
9497

@@ -98,6 +101,7 @@ def is_typescript() -> bool:
98101
99102
Returns:
100103
True if the current language is TypeScript.
104+
101105
"""
102106
return _current_language == Language.TYPESCRIPT
103107

@@ -107,7 +111,8 @@ def current_language_support() -> LanguageSupport:
107111
108112
Returns:
109113
The LanguageSupport instance for the current language.
114+
110115
"""
111116
from codeflash.languages.registry import get_language_support
112117

113-
return get_language_support(_current_language)
118+
return get_language_support(_current_language)

0 commit comments

Comments
 (0)