Skip to content

Commit 48b5ff3

Browse files
committed
refactor: merge duplicate CST pruning functions into single parameterized function
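Consolidated prune_cst_for_read_only_code and prune_cst_for_testgen_code into a single prune_cst_for_context function controlled by two flags, include_target_in_output and include_init_dunder: the read-only call site passes both as False, and the testgen call site passes both as True.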
1 parent 7b33e8b commit 48b5ff3

1 file changed

Lines changed: 63 additions & 121 deletions

File tree

codeflash/context/code_context_extractor.py

Lines changed: 63 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -969,12 +969,22 @@ def parse_code_and_prune_cst(
969969
if code_context_type == CodeContextType.READ_WRITABLE:
970970
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
971971
elif code_context_type == CodeContextType.READ_ONLY:
972-
filtered_node, found_target = prune_cst_for_read_only_code(
973-
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
972+
filtered_node, found_target = prune_cst_for_context(
973+
module,
974+
target_functions,
975+
helpers_of_helper_functions,
976+
remove_docstrings=remove_docstrings,
977+
include_target_in_output=False,
978+
include_init_dunder=False,
974979
)
975980
elif code_context_type == CodeContextType.TESTGEN:
976-
filtered_node, found_target = prune_cst_for_testgen_code(
977-
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
981+
filtered_node, found_target = prune_cst_for_context(
982+
module,
983+
target_functions,
984+
helpers_of_helper_functions,
985+
remove_docstrings=remove_docstrings,
986+
include_target_in_output=True,
987+
include_init_dunder=True,
978988
)
979989
elif code_context_type == CodeContextType.HASHING:
980990
filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)
@@ -1198,17 +1208,29 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
11981208
return (node.with_changes(**updates) if updates else node), True
11991209

12001210

1201-
def prune_cst_for_read_only_code( # noqa: PLR0911
1211+
def prune_cst_for_context( # noqa: PLR0911
12021212
node: cst.CSTNode,
12031213
target_functions: set[str],
12041214
helpers_of_helper_functions: set[str],
12051215
prefix: str = "",
12061216
remove_docstrings: bool = False, # noqa: FBT001, FBT002
1217+
include_target_in_output: bool = False, # noqa: FBT001, FBT002
1218+
include_init_dunder: bool = False, # noqa: FBT001, FBT002
12071219
) -> tuple[cst.CSTNode | None, bool]:
1208-
"""Recursively filter the node for read-only context.
1220+
"""Recursively filter the node for code context extraction.
12091221
1210-
Returns
1211-
-------
1222+
Args:
1223+
node: The CST node to filter
1224+
target_functions: Set of qualified function names that are targets
1225+
helpers_of_helper_functions: Set of helper function qualified names
1226+
prefix: Current qualified name prefix (for class methods)
1227+
remove_docstrings: Whether to remove docstrings from output
1228+
include_target_in_output: If True, include target functions in output (testgen mode)
1229+
If False, exclude target functions (read-only mode)
1230+
include_init_dunder: If True, include __init__ in dunder methods (testgen mode)
1231+
If False, exclude __init__ from dunder methods (read-only mode)
1232+
1233+
Returns:
12121234
(filtered_node, found_target):
12131235
filtered_node: The modified CST node or None if it should be removed.
12141236
found_target: True if a target function was found in this node's subtree.
@@ -1219,17 +1241,28 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
12191241

12201242
if isinstance(node, cst.FunctionDef):
12211243
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
1222-
# If it's a target function, remove it but mark found_target = True
1244+
1245+
# Check if it's a helper of helper function
12231246
if qualified_name in helpers_of_helper_functions:
1247+
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
1248+
return node.with_changes(body=remove_docstring_from_body(node.body)), True
12241249
return node, True
1250+
1251+
# Check if it's a target function
12251252
if qualified_name in target_functions:
1253+
if include_target_in_output:
1254+
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
1255+
return node.with_changes(body=remove_docstring_from_body(node.body)), True
1256+
return node, True
12261257
return None, True
1227-
# Keep only dunder methods
1228-
if is_dunder_method(node.name.value) and node.name.value != "__init__":
1258+
1259+
# Check dunder methods
1260+
# For read-only mode, exclude __init__; for testgen mode, include all dunders
1261+
if is_dunder_method(node.name.value) and (include_init_dunder or node.name.value != "__init__"):
12291262
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
1230-
new_body = remove_docstring_from_body(node.body)
1231-
return node.with_changes(body=new_body), False
1263+
return node.with_changes(body=remove_docstring_from_body(node.body)), False
12321264
return node, False
1265+
12331266
return None, False
12341267

12351268
if isinstance(node, cst.ClassDef):
@@ -1246,114 +1279,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
12461279
found_in_class = False
12471280
new_class_body: list[CSTNode] = []
12481281
for stmt in node.body.body:
1249-
filtered, found_target = prune_cst_for_read_only_code(
1250-
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
1251-
)
1252-
found_in_class |= found_target
1253-
if filtered:
1254-
new_class_body.append(filtered)
1255-
1256-
if not found_in_class:
1257-
return None, False
1258-
1259-
if remove_docstrings:
1260-
return node.with_changes(
1261-
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
1262-
) if new_class_body else None, True
1263-
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
1264-
1265-
# For other nodes, keep the node and recursively filter children
1266-
section_names = get_section_names(node)
1267-
if not section_names:
1268-
return node, False
1269-
1270-
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
1271-
found_any_target = False
1272-
1273-
for section in section_names:
1274-
original_content = getattr(node, section, None)
1275-
if isinstance(original_content, (list, tuple)):
1276-
new_children = []
1277-
section_found_target = False
1278-
for child in original_content:
1279-
filtered, found_target = prune_cst_for_read_only_code(
1280-
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
1281-
)
1282-
if filtered:
1283-
new_children.append(filtered)
1284-
section_found_target |= found_target
1285-
1286-
if section_found_target or new_children:
1287-
found_any_target |= section_found_target
1288-
updates[section] = new_children
1289-
elif original_content is not None:
1290-
filtered, found_target = prune_cst_for_read_only_code(
1291-
original_content,
1282+
filtered, found_target = prune_cst_for_context(
1283+
stmt,
12921284
target_functions,
12931285
helpers_of_helper_functions,
1294-
prefix,
1286+
class_prefix,
12951287
remove_docstrings=remove_docstrings,
1296-
)
1297-
found_any_target |= found_target
1298-
if filtered:
1299-
updates[section] = filtered
1300-
if updates:
1301-
return (node.with_changes(**updates), found_any_target)
1302-
1303-
return None, False
1304-
1305-
1306-
def prune_cst_for_testgen_code( # noqa: PLR0911
1307-
node: cst.CSTNode,
1308-
target_functions: set[str],
1309-
helpers_of_helper_functions: set[str],
1310-
prefix: str = "",
1311-
remove_docstrings: bool = False, # noqa: FBT001, FBT002
1312-
) -> tuple[cst.CSTNode | None, bool]:
1313-
"""Recursively filter the node for testgen context.
1314-
1315-
Returns
1316-
-------
1317-
(filtered_node, found_target):
1318-
filtered_node: The modified CST node or None if it should be removed.
1319-
found_target: True if a target function was found in this node's subtree.
1320-
1321-
"""
1322-
if isinstance(node, (cst.Import, cst.ImportFrom)):
1323-
return None, False
1324-
1325-
if isinstance(node, cst.FunctionDef):
1326-
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
1327-
# If it's a target function, remove it but mark found_target = True
1328-
if qualified_name in helpers_of_helper_functions or qualified_name in target_functions:
1329-
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
1330-
new_body = remove_docstring_from_body(node.body)
1331-
return node.with_changes(body=new_body), True
1332-
return node, True
1333-
# Keep all dunder methods
1334-
if is_dunder_method(node.name.value):
1335-
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
1336-
new_body = remove_docstring_from_body(node.body)
1337-
return node.with_changes(body=new_body), False
1338-
return node, False
1339-
return None, False
1340-
1341-
if isinstance(node, cst.ClassDef):
1342-
# Do not recurse into nested classes
1343-
if prefix:
1344-
return None, False
1345-
# Assuming always an IndentedBlock
1346-
if not isinstance(node.body, cst.IndentedBlock):
1347-
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
1348-
1349-
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
1350-
1351-
# First pass: detect if there is a target function in the class
1352-
found_in_class = False
1353-
new_class_body: list[CSTNode] = []
1354-
for stmt in node.body.body:
1355-
filtered, found_target = prune_cst_for_testgen_code(
1356-
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
1288+
include_target_in_output=include_target_in_output,
1289+
include_init_dunder=include_init_dunder,
13571290
)
13581291
found_in_class |= found_target
13591292
if filtered:
@@ -1382,8 +1315,14 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
13821315
new_children = []
13831316
section_found_target = False
13841317
for child in original_content:
1385-
filtered, found_target = prune_cst_for_testgen_code(
1386-
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
1318+
filtered, found_target = prune_cst_for_context(
1319+
child,
1320+
target_functions,
1321+
helpers_of_helper_functions,
1322+
prefix,
1323+
remove_docstrings=remove_docstrings,
1324+
include_target_in_output=include_target_in_output,
1325+
include_init_dunder=include_init_dunder,
13871326
)
13881327
if filtered:
13891328
new_children.append(filtered)
@@ -1393,16 +1332,19 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
13931332
found_any_target |= section_found_target
13941333
updates[section] = new_children
13951334
elif original_content is not None:
1396-
filtered, found_target = prune_cst_for_testgen_code(
1335+
filtered, found_target = prune_cst_for_context(
13971336
original_content,
13981337
target_functions,
13991338
helpers_of_helper_functions,
14001339
prefix,
14011340
remove_docstrings=remove_docstrings,
1341+
include_target_in_output=include_target_in_output,
1342+
include_init_dunder=include_init_dunder,
14021343
)
14031344
found_any_target |= found_target
14041345
if filtered:
14051346
updates[section] = filtered
1347+
14061348
if updates:
14071349
return (node.with_changes(**updates), found_any_target)
14081350

0 commit comments

Comments
 (0)