Skip to content

Commit 6e3d6e7

Browse files
committed
almost ready
1 parent 5292d67 commit 6e3d6e7

3 files changed

Lines changed: 216 additions & 12 deletions

File tree

code_to_optimize/sample_code.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
from functools import partial
2+
from typing import Any
23

34
import jax.numpy as jnp
45
import numpy as np
56
import tensorflow as tf
67
import torch
78
from jax import lax
9+
from torch import nn
810

911

12+
class AlexNet(nn.Module):
    """Tiny sample model: one linear layer mapping 5 input features to ``num_classes`` outputs.

    Despite the name, this is not the AlexNet architecture; it only wraps a
    single ``nn.Linear`` layer.

    Args:
        num_classes: Number of output features produced by the linear layer.
            Defaults to 10.
        *args: Extra positional arguments forwarded to ``nn.Module.__init__``.
        **kwargs: Extra keyword arguments forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_classes=10, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # Size the output from num_classes rather than a hard-coded 10 so the
        # constructor argument is honoured; the default keeps the old shape.
        self.layer = nn.Linear(5, num_classes)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.layer(x)
21+
1022
def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
1123
n = len(b)
1224

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import torch
2+
3+
from code_to_optimize.sample_code import AlexNet
4+
5+
def test_models():
    """Compare a seeded AlexNet call on a (2, 5) random batch with hard-coded values."""
    torch.manual_seed(42)
    model = AlexNet(num_classes=10)
    batch = torch.randn(2, 5)
    # The instance is called directly (not via .forward) on purpose.
    actual = model(batch)
    expected = torch.Tensor(
        [
            [
                0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
                0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
                0.3680166304, 0.3558489084,
            ],
            [
                -0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
                -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
                0.2874411345, -0.4801278412,
            ],
        ]
    )
    assert torch.allclose(actual, expected)
16+
17+
def test_models1():
    """Compare a seeded AlexNet call on a (2, 5) random batch with hard-coded values."""
    torch.manual_seed(42)
    model = AlexNet(num_classes=10)
    batch = torch.randn(2, 5)
    # The instance is called directly (not via .forward) on purpose.
    actual = model(batch)
    expected = torch.Tensor(
        [
            [
                0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
                0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
                0.3680166304, 0.3558489084,
            ],
            [
                -0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
                -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
                0.2874411345, -0.4801278412,
            ],
        ]
    )
    assert torch.allclose(actual, expected)
28+
29+
def test_models2():
    """Compare a seeded AlexNet call on a (2, 5) random batch with hard-coded values."""
    torch.manual_seed(42)
    model = AlexNet(num_classes=10)
    batch = torch.randn(2, 5)
    # The instance is called directly (not via .forward) on purpose.
    actual = model(batch)
    expected = torch.Tensor(
        [
            [
                0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
                0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
                0.3680166304, 0.3558489084,
            ],
            [
                -0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
                -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
                0.2874411345, -0.4801278412,
            ],
        ]
    )
    assert torch.allclose(actual, expected)
40+
41+
def test_models3():
    """Compare a seeded AlexNet call on a (2, 5) random batch with hard-coded values."""
    torch.manual_seed(42)
    model = AlexNet(num_classes=10)
    batch = torch.randn(2, 5)
    # The instance is called directly (not via .forward) on purpose.
    actual = model(batch)
    expected = torch.Tensor(
        [
            [
                0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
                0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
                0.3680166304, 0.3558489084,
            ],
            [
                -0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
                -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
                0.2874411345, -0.4801278412,
            ],
        ]
    )
    assert torch.allclose(actual, expected)
52+
53+
def test_models4():
    """Compare a seeded AlexNet call on a (2, 5) random batch with hard-coded values."""
    torch.manual_seed(42)
    model = AlexNet(num_classes=10)
    batch = torch.randn(2, 5)
    # The instance is called directly (not via .forward) on purpose.
    actual = model(batch)
    expected = torch.Tensor(
        [
            [
                0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
                0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
                0.3680166304, 0.3558489084,
            ],
            [
                -0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
                -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
                0.2874411345, -0.4801278412,
            ],
        ]
    )
    assert torch.allclose(actual, expected)

codeflash/discovery/discover_unit_tests.py

Lines changed: 141 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -265,27 +265,22 @@ def visit_Import(self, node: ast.Import) -> None:
265265

266266
def visit_Assign(self, node: ast.Assign) -> None:
267267
"""Track variable assignments, especially class instantiations."""
268-
if self.found_any_target_function:
269-
return
270-
271-
# Check if the assignment is a class instantiation
268+
# Always track instance assignments, even if we've found a target function
269+
# This is needed for the PyTorch nn.Module pattern where model(x) calls forward(x)
272270
value = node.value
273271
if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
274272
class_name = value.func.id
275273
if class_name in self.imported_modules:
276274
# Map the variable to the actual class name (handling aliases)
277275
original_class = self.alias_mapping.get(class_name, class_name)
278-
# Use list comprehension for direct assignment to instance_mapping, reducing loop overhead
279276
targets = node.targets
280-
instance_mapping = self.instance_mapping
281-
# since ast.Name nodes are heavily used, avoid local lookup for isinstance
282-
# and reuse locals for faster attribute access
283277
for target in targets:
284278
if isinstance(target, ast.Name):
285-
instance_mapping[target.id] = original_class
279+
self.instance_mapping[target.id] = original_class
286280

287-
# Continue visiting child nodes
288-
self.generic_visit(node)
281+
# Continue visiting child nodes if we haven't found a target function yet
282+
if not self.found_any_target_function:
283+
self.generic_visit(node)
289284

290285
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
291286
"""Handle 'from module import name' statements."""
@@ -405,7 +400,7 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
405400
ast.NodeVisitor.generic_visit(self, node)
406401

407402
def visit_Call(self, node: ast.Call) -> None:
408-
"""Handle function calls, particularly __import__."""
403+
"""Handle function calls, particularly __import__ and instance calls for nn.Module.forward."""
409404
if self.found_any_target_function:
410405
return
411406

@@ -415,6 +410,19 @@ def visit_Call(self, node: ast.Call) -> None:
415410
# When __import__ is used, any target function could potentially be imported
416411
# Be conservative and assume it might import target functions
417412

413+
# Check if this is a call on an instance variable (PyTorch nn.Module pattern)
414+
# When model = AlexNet(...) and we call model(input_data), this invokes forward()
415+
if isinstance(node.func, ast.Name):
416+
instance_name = node.func.id
417+
if instance_name in self.instance_mapping:
418+
class_name = self.instance_mapping[instance_name]
419+
# Check if ClassName.forward is in our target functions
420+
roots_possible = self._dot_methods.get("forward")
421+
if roots_possible and class_name in roots_possible:
422+
self.found_any_target_function = True
423+
self.found_qualified_name = self._class_method_to_target[(class_name, "forward")]
424+
return
425+
418426
self.generic_visit(node)
419427

420428
def visit_Name(self, node: ast.Name) -> None:
@@ -495,6 +503,68 @@ def _fast_generic_visit(self, node: ast.AST) -> None:
495503
append((value._fields, value))
496504

497505

506+
class InstanceMappingExtractor(ast.NodeVisitor):
    """Collect ``variable name -> class name`` pairs for instantiated imported names.

    An assignment is recorded when its value is a call to a bare name that was
    bound by an import statement, e.g. ``model = AlexNet(...)`` or
    ``model: AlexNet = AlexNet(...)``. This is needed for detecting PyTorch
    nn.Module.forward calls where model(x) calls forward(x).

    The mapping is flat for the whole file: scopes are not tracked, so a later
    assignment to the same variable name overwrites an earlier one.
    """

    def __init__(self) -> None:
        # Names bound by import statements (the alias when one is given).
        self.imported_modules: set[str] = set()
        # Alias -> original name; only filled for ``from x import y as z``.
        self.alias_mapping: dict[str, str] = {}
        # Variable name -> original (un-aliased) class name.
        self.instance_mapping: dict[str, str] = {}

    def visit_Import(self, node: ast.Import) -> None:
        """Record the local name bound by each ``import x`` / ``import x as y``."""
        for alias in node.names:
            self.imported_modules.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record names bound by ``from module import name [as alias]``.

        Imports without a module (``from . import x``) and star imports are
        ignored.
        """
        if not node.module:
            return
        for alias in node.names:
            if alias.name == "*":
                continue
            imported_name = alias.asname or alias.name
            self.imported_modules.add(imported_name)
            if alias.asname:
                self.alias_mapping[imported_name] = alias.name
        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Record ``name = ImportedClass(...)`` assignments, then keep descending."""
        self._record_instances(node.value, node.targets)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Record ``name: T = ImportedClass(...)``; bare declarations have no value."""
        if node.value is not None:
            self._record_instances(node.value, [node.target])
        self.generic_visit(node)

    def _record_instances(self, value: ast.expr, targets: list[ast.expr]) -> None:
        """Map every plain-name target to the class when ``value`` instantiates an imported name."""
        if not (isinstance(value, ast.Call) and isinstance(value.func, ast.Name)):
            return
        class_name = value.func.id
        if class_name not in self.imported_modules:
            return
        # Resolve ``from x import Y as Z`` aliases back to the original name.
        original_class = self.alias_mapping.get(class_name, class_name)
        for target in targets:
            if isinstance(target, ast.Name):
                self.instance_mapping[target.id] = original_class
545+
546+
547+
def extract_instance_mapping(test_file_path: Path) -> dict[str, str]:
    """Extract instance-to-class mappings from a test file.

    This is a best-effort helper: a file that cannot be read, decoded, or
    parsed yields an empty mapping instead of raising.

    Args:
        test_file_path: Path to the test file.

    Returns:
        Dictionary mapping instance variable names to class names, or an
        empty dictionary when the file cannot be read or parsed.

    """
    try:
        source_code = test_file_path.read_text(encoding="utf-8")
        tree = ast.parse(source_code, filename=str(test_file_path))
    except (OSError, SyntaxError, ValueError):
        # OSError covers missing files, directories and permission errors;
        # ValueError covers UnicodeDecodeError from non-UTF-8 content and the
        # ValueError ast.parse raises for some malformed sources.
        return {}

    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    return extractor.instance_mapping
566+
567+
498568
def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
499569
"""Analyze a test file to see if it imports any of the target functions."""
500570
try:
@@ -879,6 +949,10 @@ def process_test_files(
879949
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
880950
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
881951

952+
# Get instance-to-class mappings for PyTorch nn.Module.forward detection
953+
# When model = AlexNet(...) and model(x) is called, it invokes forward(x)
954+
instance_to_class_mapping = extract_instance_mapping(test_file) if functions_to_optimize else {}
955+
882956
except Exception as e:
883957
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
884958
progress.advance(task_id)
@@ -1017,6 +1091,61 @@ def process_test_files(
10171091
num_discovered_replay_tests += 1
10181092

10191093
num_discovered_tests += 1
1094+
1095+
# Also check for PyTorch nn.Module pattern: model(x) -> forward(x)
1096+
# When an instance variable is called, it invokes the forward method
1097+
if name.name in instance_to_class_mapping:
1098+
class_name = instance_to_class_mapping[name.name]
1099+
for func_to_opt in functions_to_optimize:
1100+
# Check if the target is ClassName.forward
1101+
if (
1102+
func_to_opt.function_name == "forward"
1103+
and func_to_opt.top_level_parent_name == class_name
1104+
):
1105+
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
1106+
project_root_path
1107+
)
1108+
1109+
for test_func in test_functions_by_name[scope]:
1110+
if test_func.parameters is not None:
1111+
if test_framework == "pytest":
1112+
scope_test_function = (
1113+
f"{test_func.function_name}[{test_func.parameters}]"
1114+
)
1115+
else: # unittest
1116+
scope_test_function = (
1117+
f"{test_func.function_name}_{test_func.parameters}"
1118+
)
1119+
else:
1120+
scope_test_function = test_func.function_name
1121+
1122+
function_to_test_map[qualified_name_with_modules].add(
1123+
FunctionCalledInTest(
1124+
tests_in_file=TestsInFile(
1125+
test_file=test_file,
1126+
test_class=test_func.test_class,
1127+
test_function=scope_test_function,
1128+
test_type=test_func.test_type,
1129+
),
1130+
position=CodePosition(line_no=name.line, col_no=name.column),
1131+
)
1132+
)
1133+
tests_cache.insert_test(
1134+
file_path=str(test_file),
1135+
file_hash=file_hash,
1136+
qualified_name_with_modules_from_root=qualified_name_with_modules,
1137+
function_name=scope,
1138+
test_class=test_func.test_class or "",
1139+
test_function=scope_test_function,
1140+
test_type=test_func.test_type,
1141+
line_number=name.line,
1142+
col_number=name.column,
1143+
)
1144+
1145+
if test_func.test_type == TestType.REPLAY_TEST:
1146+
num_discovered_replay_tests += 1
1147+
1148+
num_discovered_tests += 1
10201149
continue
10211150
definition_obj = definition[0]
10221151
definition_path = str(definition_obj.module_path)

0 commit comments

Comments
 (0)