@@ -922,7 +922,7 @@ def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]:
922922 names .add (alias .name .value )
923923 elif isinstance (alias .name , cst .Attribute ):
924924 # import foo.bar -> accessible as "foo"
925- base = alias .name
925+ base : cst . BaseExpression = alias .name
926926 while isinstance (base , cst .Attribute ):
927927 base = base .value
928928 if isinstance (base , cst .Name ):
@@ -969,12 +969,22 @@ def parse_code_and_prune_cst(
969969 if code_context_type == CodeContextType .READ_WRITABLE :
970970 filtered_node , found_target = prune_cst_for_read_writable_code (module , target_functions , defs_with_usages )
971971 elif code_context_type == CodeContextType .READ_ONLY :
972- filtered_node , found_target = prune_cst_for_read_only_code (
973- module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
972+ filtered_node , found_target = prune_cst_for_context (
973+ module ,
974+ target_functions ,
975+ helpers_of_helper_functions ,
976+ remove_docstrings = remove_docstrings ,
977+ include_target_in_output = False ,
978+ include_init_dunder = False ,
974979 )
975980 elif code_context_type == CodeContextType .TESTGEN :
976- filtered_node , found_target = prune_cst_for_testgen_code (
977- module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
981+ filtered_node , found_target = prune_cst_for_context (
982+ module ,
983+ target_functions ,
984+ helpers_of_helper_functions ,
985+ remove_docstrings = remove_docstrings ,
986+ include_target_in_output = True ,
987+ include_init_dunder = True ,
978988 )
979989 elif code_context_type == CodeContextType .HASHING :
980990 filtered_node , found_target = prune_cst_for_code_hashing (module , target_functions )
@@ -1198,17 +1208,29 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
11981208 return (node .with_changes (** updates ) if updates else node ), True
11991209
12001210
1201- def prune_cst_for_read_only_code ( # noqa: PLR0911
1211+ def prune_cst_for_context ( # noqa: PLR0911
12021212 node : cst .CSTNode ,
12031213 target_functions : set [str ],
12041214 helpers_of_helper_functions : set [str ],
12051215 prefix : str = "" ,
12061216 remove_docstrings : bool = False , # noqa: FBT001, FBT002
1217+ include_target_in_output : bool = False , # noqa: FBT001, FBT002
1218+ include_init_dunder : bool = False , # noqa: FBT001, FBT002
12071219) -> tuple [cst .CSTNode | None , bool ]:
1208- """Recursively filter the node for read-only context.
1220+ """Recursively filter the node for code context extraction .
12091221
1210- Returns
1211- -------
1222+ Args:
1223+ node: The CST node to filter
1224+ target_functions: Set of qualified function names that are targets
1225+ helpers_of_helper_functions: Set of helper function qualified names
1226+ prefix: Current qualified name prefix (for class methods)
1227+ remove_docstrings: Whether to remove docstrings from output
1228+ include_target_in_output: If True, include target functions in output (testgen mode)
1229+ If False, exclude target functions (read-only mode)
1230+ include_init_dunder: If True, include __init__ in dunder methods (testgen mode)
1231+ If False, exclude __init__ from dunder methods (read-only mode)
1232+
1233+ Returns:
12121234 (filtered_node, found_target):
12131235 filtered_node: The modified CST node or None if it should be removed.
12141236 found_target: True if a target function was found in this node's subtree.
@@ -1219,17 +1241,28 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
12191241
12201242 if isinstance (node , cst .FunctionDef ):
12211243 qualified_name = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
1222- # If it's a target function, remove it but mark found_target = True
1244+
1245+ # Check if it's a helper of helper function
12231246 if qualified_name in helpers_of_helper_functions :
1247+ if remove_docstrings and isinstance (node .body , cst .IndentedBlock ):
1248+ return node .with_changes (body = remove_docstring_from_body (node .body )), True
12241249 return node , True
1250+
1251+ # Check if it's a target function
12251252 if qualified_name in target_functions :
1253+ if include_target_in_output :
1254+ if remove_docstrings and isinstance (node .body , cst .IndentedBlock ):
1255+ return node .with_changes (body = remove_docstring_from_body (node .body )), True
1256+ return node , True
12261257 return None , True
1227- # Keep only dunder methods
1228- if is_dunder_method (node .name .value ) and node .name .value != "__init__" :
1258+
1259+ # Check dunder methods
1260+ # For read-only mode, exclude __init__; for testgen mode, include all dunders
1261+ if is_dunder_method (node .name .value ) and (include_init_dunder or node .name .value != "__init__" ):
12291262 if remove_docstrings and isinstance (node .body , cst .IndentedBlock ):
1230- new_body = remove_docstring_from_body (node .body )
1231- return node .with_changes (body = new_body ), False
1263+ return node .with_changes (body = remove_docstring_from_body (node .body )), False
12321264 return node , False
1265+
12331266 return None , False
12341267
12351268 if isinstance (node , cst .ClassDef ):
@@ -1246,114 +1279,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
12461279 found_in_class = False
12471280 new_class_body : list [CSTNode ] = []
12481281 for stmt in node .body .body :
1249- filtered , found_target = prune_cst_for_read_only_code (
1250- stmt , target_functions , helpers_of_helper_functions , class_prefix , remove_docstrings = remove_docstrings
1251- )
1252- found_in_class |= found_target
1253- if filtered :
1254- new_class_body .append (filtered )
1255-
1256- if not found_in_class :
1257- return None , False
1258-
1259- if remove_docstrings :
1260- return node .with_changes (
1261- body = remove_docstring_from_body (node .body .with_changes (body = new_class_body ))
1262- ) if new_class_body else None , True
1263- return node .with_changes (body = node .body .with_changes (body = new_class_body )) if new_class_body else None , True
1264-
1265- # For other nodes, keep the node and recursively filter children
1266- section_names = get_section_names (node )
1267- if not section_names :
1268- return node , False
1269-
1270- updates : dict [str , list [cst .CSTNode ] | cst .CSTNode ] = {}
1271- found_any_target = False
1272-
1273- for section in section_names :
1274- original_content = getattr (node , section , None )
1275- if isinstance (original_content , (list , tuple )):
1276- new_children = []
1277- section_found_target = False
1278- for child in original_content :
1279- filtered , found_target = prune_cst_for_read_only_code (
1280- child , target_functions , helpers_of_helper_functions , prefix , remove_docstrings = remove_docstrings
1281- )
1282- if filtered :
1283- new_children .append (filtered )
1284- section_found_target |= found_target
1285-
1286- if section_found_target or new_children :
1287- found_any_target |= section_found_target
1288- updates [section ] = new_children
1289- elif original_content is not None :
1290- filtered , found_target = prune_cst_for_read_only_code (
1291- original_content ,
1282+ filtered , found_target = prune_cst_for_context (
1283+ stmt ,
12921284 target_functions ,
12931285 helpers_of_helper_functions ,
1294- prefix ,
1286+ class_prefix ,
12951287 remove_docstrings = remove_docstrings ,
1296- )
1297- found_any_target |= found_target
1298- if filtered :
1299- updates [section ] = filtered
1300- if updates :
1301- return (node .with_changes (** updates ), found_any_target )
1302-
1303- return None , False
1304-
1305-
1306- def prune_cst_for_testgen_code ( # noqa: PLR0911
1307- node : cst .CSTNode ,
1308- target_functions : set [str ],
1309- helpers_of_helper_functions : set [str ],
1310- prefix : str = "" ,
1311- remove_docstrings : bool = False , # noqa: FBT001, FBT002
1312- ) -> tuple [cst .CSTNode | None , bool ]:
1313- """Recursively filter the node for testgen context.
1314-
1315- Returns
1316- -------
1317- (filtered_node, found_target):
1318- filtered_node: The modified CST node or None if it should be removed.
1319- found_target: True if a target function was found in this node's subtree.
1320-
1321- """
1322- if isinstance (node , (cst .Import , cst .ImportFrom )):
1323- return None , False
1324-
1325- if isinstance (node , cst .FunctionDef ):
1326- qualified_name = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
1327- # If it's a target function, remove it but mark found_target = True
1328- if qualified_name in helpers_of_helper_functions or qualified_name in target_functions :
1329- if remove_docstrings and isinstance (node .body , cst .IndentedBlock ):
1330- new_body = remove_docstring_from_body (node .body )
1331- return node .with_changes (body = new_body ), True
1332- return node , True
1333- # Keep all dunder methods
1334- if is_dunder_method (node .name .value ):
1335- if remove_docstrings and isinstance (node .body , cst .IndentedBlock ):
1336- new_body = remove_docstring_from_body (node .body )
1337- return node .with_changes (body = new_body ), False
1338- return node , False
1339- return None , False
1340-
1341- if isinstance (node , cst .ClassDef ):
1342- # Do not recurse into nested classes
1343- if prefix :
1344- return None , False
1345- # Assuming always an IndentedBlock
1346- if not isinstance (node .body , cst .IndentedBlock ):
1347- raise ValueError ("ClassDef body is not an IndentedBlock" ) # noqa: TRY004
1348-
1349- class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
1350-
1351- # First pass: detect if there is a target function in the class
1352- found_in_class = False
1353- new_class_body : list [CSTNode ] = []
1354- for stmt in node .body .body :
1355- filtered , found_target = prune_cst_for_testgen_code (
1356- stmt , target_functions , helpers_of_helper_functions , class_prefix , remove_docstrings = remove_docstrings
1288+ include_target_in_output = include_target_in_output ,
1289+ include_init_dunder = include_init_dunder ,
13571290 )
13581291 found_in_class |= found_target
13591292 if filtered :
@@ -1382,8 +1315,14 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
13821315 new_children = []
13831316 section_found_target = False
13841317 for child in original_content :
1385- filtered , found_target = prune_cst_for_testgen_code (
1386- child , target_functions , helpers_of_helper_functions , prefix , remove_docstrings = remove_docstrings
1318+ filtered , found_target = prune_cst_for_context (
1319+ child ,
1320+ target_functions ,
1321+ helpers_of_helper_functions ,
1322+ prefix ,
1323+ remove_docstrings = remove_docstrings ,
1324+ include_target_in_output = include_target_in_output ,
1325+ include_init_dunder = include_init_dunder ,
13871326 )
13881327 if filtered :
13891328 new_children .append (filtered )
@@ -1393,16 +1332,19 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
13931332 found_any_target |= section_found_target
13941333 updates [section ] = new_children
13951334 elif original_content is not None :
1396- filtered , found_target = prune_cst_for_testgen_code (
1335+ filtered , found_target = prune_cst_for_context (
13971336 original_content ,
13981337 target_functions ,
13991338 helpers_of_helper_functions ,
14001339 prefix ,
14011340 remove_docstrings = remove_docstrings ,
1341+ include_target_in_output = include_target_in_output ,
1342+ include_init_dunder = include_init_dunder ,
14021343 )
14031344 found_any_target |= found_target
14041345 if filtered :
14051346 updates [section ] = filtered
1347+
14061348 if updates :
14071349 return (node .with_changes (** updates ), found_any_target )
14081350
0 commit comments