@@ -265,27 +265,22 @@ def visit_Import(self, node: ast.Import) -> None:
265265
266266 def visit_Assign (self , node : ast .Assign ) -> None :
267267 """Track variable assignments, especially class instantiations."""
268- if self .found_any_target_function :
269- return
270-
271- # Check if the assignment is a class instantiation
268+ # Always track instance assignments, even if we've found a target function
269+ # This is needed for the PyTorch nn.Module pattern where model(x) calls forward(x)
272270 value = node .value
273271 if isinstance (value , ast .Call ) and isinstance (value .func , ast .Name ):
274272 class_name = value .func .id
275273 if class_name in self .imported_modules :
276274 # Map the variable to the actual class name (handling aliases)
277275 original_class = self .alias_mapping .get (class_name , class_name )
278- # Use list comprehension for direct assignment to instance_mapping, reducing loop overhead
279276 targets = node .targets
280- instance_mapping = self .instance_mapping
281- # since ast.Name nodes are heavily used, avoid local lookup for isinstance
282- # and reuse locals for faster attribute access
283277 for target in targets :
284278 if isinstance (target , ast .Name ):
285- instance_mapping [target .id ] = original_class
279+ self . instance_mapping [target .id ] = original_class
286280
287- # Continue visiting child nodes
288- self .generic_visit (node )
281+ # Continue visiting child nodes if we haven't found a target function yet
282+ if not self .found_any_target_function :
283+ self .generic_visit (node )
289284
290285 def visit_ImportFrom (self , node : ast .ImportFrom ) -> None :
291286 """Handle 'from module import name' statements."""
@@ -405,7 +400,7 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
405400 ast .NodeVisitor .generic_visit (self , node )
406401
407402 def visit_Call (self , node : ast .Call ) -> None :
408- """Handle function calls, particularly __import__."""
403+ """Handle function calls, particularly __import__ and instance calls for nn.Module.forward ."""
409404 if self .found_any_target_function :
410405 return
411406
@@ -415,6 +410,19 @@ def visit_Call(self, node: ast.Call) -> None:
415410 # When __import__ is used, any target function could potentially be imported
416411 # Be conservative and assume it might import target functions
417412
413+ # Check if this is a call on an instance variable (PyTorch nn.Module pattern)
414+ # When model = AlexNet(...) and we call model(input_data), this invokes forward()
415+ if isinstance (node .func , ast .Name ):
416+ instance_name = node .func .id
417+ if instance_name in self .instance_mapping :
418+ class_name = self .instance_mapping [instance_name ]
419+ # Check if ClassName.forward is in our target functions
420+ roots_possible = self ._dot_methods .get ("forward" )
421+ if roots_possible and class_name in roots_possible :
422+ self .found_any_target_function = True
423+ self .found_qualified_name = self ._class_method_to_target [(class_name , "forward" )]
424+ return
425+
418426 self .generic_visit (node )
419427
420428 def visit_Name (self , node : ast .Name ) -> None :
@@ -495,6 +503,68 @@ def _fast_generic_visit(self, node: ast.AST) -> None:
495503 append ((value ._fields , value ))
496504
497505
506+ class InstanceMappingExtractor (ast .NodeVisitor ):
507+ """Simple visitor to extract instance-to-class mappings from a file.
508+
509+ This is needed for detecting PyTorch nn.Module.forward calls where model(x) calls forward(x).
510+ """
511+
512+ def __init__ (self ) -> None :
513+ self .imported_modules : set [str ] = set ()
514+ self .alias_mapping : dict [str , str ] = {}
515+ self .instance_mapping : dict [str , str ] = {}
516+
517+ def visit_Import (self , node : ast .Import ) -> None :
518+ for alias in node .names :
519+ module_name = alias .asname if alias .asname else alias .name
520+ self .imported_modules .add (module_name )
521+ self .generic_visit (node )
522+
523+ def visit_ImportFrom (self , node : ast .ImportFrom ) -> None :
524+ if not node .module :
525+ return
526+ for alias in node .names :
527+ if alias .name == "*" :
528+ continue
529+ imported_name = alias .asname if alias .asname else alias .name
530+ self .imported_modules .add (imported_name )
531+ if alias .asname :
532+ self .alias_mapping [imported_name ] = alias .name
533+ self .generic_visit (node )
534+
535+ def visit_Assign (self , node : ast .Assign ) -> None :
536+ value = node .value
537+ if isinstance (value , ast .Call ) and isinstance (value .func , ast .Name ):
538+ class_name = value .func .id
539+ if class_name in self .imported_modules :
540+ original_class = self .alias_mapping .get (class_name , class_name )
541+ for target in node .targets :
542+ if isinstance (target , ast .Name ):
543+ self .instance_mapping [target .id ] = original_class
544+ self .generic_visit (node )
545+
546+
547+ def extract_instance_mapping (test_file_path : Path ) -> dict [str , str ]:
548+ """Extract instance-to-class mappings from a test file.
549+
550+ Args:
551+ test_file_path: Path to the test file.
552+
553+ Returns:
554+ Dictionary mapping instance variable names to class names.
555+
556+ """
557+ try :
558+ with test_file_path .open ("r" , encoding = "utf-8" ) as f :
559+ source_code = f .read ()
560+ tree = ast .parse (source_code , filename = str (test_file_path ))
561+ extractor = InstanceMappingExtractor ()
562+ extractor .visit (tree )
563+ return extractor .instance_mapping
564+ except (SyntaxError , FileNotFoundError ):
565+ return {}
566+
567+
498568def analyze_imports_in_test_file (test_file_path : Path | str , target_functions : set [str ]) -> bool :
499569 """Analyze a test file to see if it imports any of the target functions."""
500570 try :
@@ -879,6 +949,10 @@ def process_test_files(
879949 top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
880950 top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
881951
952+ # Get instance-to-class mappings for PyTorch nn.Module.forward detection
953+ # When model = AlexNet(...) and model(x) is called, it invokes forward(x)
954+ instance_to_class_mapping = extract_instance_mapping (test_file ) if functions_to_optimize else {}
955+
882956 except Exception as e :
883957 logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
884958 progress .advance (task_id )
@@ -1017,6 +1091,61 @@ def process_test_files(
10171091 num_discovered_replay_tests += 1
10181092
10191093 num_discovered_tests += 1
1094+
1095+ # Also check for PyTorch nn.Module pattern: model(x) -> forward(x)
1096+ # When an instance variable is called, it invokes the forward method
1097+ if name .name in instance_to_class_mapping :
1098+ class_name = instance_to_class_mapping [name .name ]
1099+ for func_to_opt in functions_to_optimize :
1100+ # Check if the target is ClassName.forward
1101+ if (
1102+ func_to_opt .function_name == "forward"
1103+ and func_to_opt .top_level_parent_name == class_name
1104+ ):
1105+ qualified_name_with_modules = func_to_opt .qualified_name_with_modules_from_root (
1106+ project_root_path
1107+ )
1108+
1109+ for test_func in test_functions_by_name [scope ]:
1110+ if test_func .parameters is not None :
1111+ if test_framework == "pytest" :
1112+ scope_test_function = (
1113+ f"{ test_func .function_name } [{ test_func .parameters } ]"
1114+ )
1115+ else : # unittest
1116+ scope_test_function = (
1117+ f"{ test_func .function_name } _{ test_func .parameters } "
1118+ )
1119+ else :
1120+ scope_test_function = test_func .function_name
1121+
1122+ function_to_test_map [qualified_name_with_modules ].add (
1123+ FunctionCalledInTest (
1124+ tests_in_file = TestsInFile (
1125+ test_file = test_file ,
1126+ test_class = test_func .test_class ,
1127+ test_function = scope_test_function ,
1128+ test_type = test_func .test_type ,
1129+ ),
1130+ position = CodePosition (line_no = name .line , col_no = name .column ),
1131+ )
1132+ )
1133+ tests_cache .insert_test (
1134+ file_path = str (test_file ),
1135+ file_hash = file_hash ,
1136+ qualified_name_with_modules_from_root = qualified_name_with_modules ,
1137+ function_name = scope ,
1138+ test_class = test_func .test_class or "" ,
1139+ test_function = scope_test_function ,
1140+ test_type = test_func .test_type ,
1141+ line_number = name .line ,
1142+ col_number = name .column ,
1143+ )
1144+
1145+ if test_func .test_type == TestType .REPLAY_TEST :
1146+ num_discovered_replay_tests += 1
1147+
1148+ num_discovered_tests += 1
10201149 continue
10211150 definition_obj = definition [0 ]
10221151 definition_path = str (definition_obj .module_path )
0 commit comments