Skip to content

Commit 6ab2710

Browse files
authored
Merge pull request #1033 from codeflash-ai/optimize-tracer-replay
simplify E2E replay test to reduce load in CI
2 parents f828b1c + d9161e7 commit 6ab2710

5 files changed

Lines changed: 23 additions & 90 deletions

File tree

Binary file not shown.
Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,39 @@
11
from concurrent.futures import ThreadPoolExecutor
2-
from time import sleep
32

43

54
def funcA(number):
6-
number = number if number < 1000 else 1000
5+
number = number if number < 100 else 100
76
k = 0
8-
for i in range(number * 100):
7+
for i in range(number * 10):
98
k += i
10-
# Simplify the for loop by using sum with a range object
119
j = sum(range(number))
12-
13-
# Use a generator expression directly in join for more efficiency
1410
return " ".join(str(i) for i in range(number))
1511

1612

1713
def test_threadpool() -> None:
18-
pool = ThreadPoolExecutor(max_workers=3)
19-
args = list(range(10, 31, 10))
14+
pool = ThreadPoolExecutor(max_workers=2)
15+
args = [5, 10, 15]
2016
result = pool.map(funcA, args)
2117

2218
for r in result:
2319
print(r)
2420

2521
class AlexNet:
26-
def __init__(self, num_classes=1000):
22+
def __init__(self, num_classes=10):
2723
self.num_classes = num_classes
28-
self.features_size = 256 * 6 * 6
2924

3025
def forward(self, x):
31-
features = self._extract_features(x)
32-
33-
output = self._classify(features)
34-
return output
35-
36-
def _extract_features(self, x):
37-
result = []
38-
for i in range(len(x)):
39-
pass
40-
41-
return result
42-
43-
def _classify(self, features):
44-
total = sum(features)
45-
return [total % self.num_classes for _ in features]
46-
47-
class SimpleModel:
48-
@staticmethod
49-
def predict(data):
50-
result = []
51-
sleep(0.1) # can be optimized away
52-
for i in range(500):
53-
for x in data:
54-
computation = 0
55-
computation += x * i ** 2
56-
result.append(computation)
57-
return result
58-
59-
@classmethod
60-
def create_default(cls):
61-
return cls()
26+
result = 0
27+
for val in x:
28+
result += val * val
29+
return result % self.num_classes
6230

6331

6432
def test_models():
6533
model = AlexNet(num_classes=10)
6634
input_data = [1, 2, 3, 4, 5]
6735
result = model.forward(input_data)
6836

69-
model2 = SimpleModel.create_default()
70-
prediction = model2.predict(input_data)
71-
7237
if __name__ == "__main__":
7338
test_threadpool()
7439
test_models()

tests/scripts/end_to_end_test_tracer_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def run_test(expected_improvement_pct: int) -> bool:
1010
min_improvement_x=0.1,
1111
expected_unit_tests_count=None, # Tracer creates replay tests dynamically, skip validation
1212
coverage_expectations=[
13-
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[6, 7, 8, 9, 11, 14])
13+
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 9, 10])
1414
],
1515
)
1616
cwd = (

tests/scripts/end_to_end_test_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
262262
if not functions_traced:
263263
logging.error("Failed to find traced functions in output")
264264
return False
265-
if int(functions_traced.group(1)) != 13:
265+
if int(functions_traced.group(1)) != 8:
266266
logging.error(functions_traced.groups())
267-
logging.error("Expected 13 traced functions")
267+
logging.error("Expected 8 traced functions")
268268
return False
269269

270270
# Validate optimization results (from optimization phase)

tests/test_function_ranker.py

Lines changed: 11 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def test_load_function_stats(function_ranker):
5858
# Verify funcA specific values
5959
assert func_a_stats["function_name"] == "funcA"
6060
assert func_a_stats["call_count"] == 1
61-
assert func_a_stats["own_time_ns"] == 63000
62-
assert func_a_stats["cumulative_time_ns"] == 5443000
61+
assert func_a_stats["own_time_ns"] == 153000
62+
assert func_a_stats["cumulative_time_ns"] == 1324000
6363

6464

6565
def test_get_function_addressable_time(function_ranker, workload_functions):
@@ -71,10 +71,10 @@ def test_get_function_addressable_time(function_ranker, workload_functions):
7171

7272
assert func_a is not None
7373
addressable_time = function_ranker.get_function_addressable_time(func_a)
74-
74+
7575
# Expected addressable time: own_time + (time_in_callees / call_count)
76-
# = 63000 + ((5443000 - 63000) / 1) = 5443000
77-
assert addressable_time == 5443000
76+
# = 153000 + ((1324000 - 153000) / 1) = 1324000
77+
assert addressable_time == 1324000
7878

7979

8080
def test_rank_functions(function_ranker, workload_functions):
@@ -107,9 +107,9 @@ def test_get_function_stats_summary(function_ranker, workload_functions):
107107

108108
assert stats is not None
109109
assert stats["function_name"] == "funcA"
110-
assert stats["own_time_ns"] == 63000
111-
assert stats["cumulative_time_ns"] == 5443000
112-
assert stats["addressable_time_ns"] == 5443000
110+
assert stats["own_time_ns"] == 153000
111+
assert stats["cumulative_time_ns"] == 1324000
112+
assert stats["addressable_time_ns"] == 1324000
113113

114114

115115

@@ -128,40 +128,8 @@ def test_importance_calculation(function_ranker):
128128

129129
assert func_a_stats is not None
130130
importance = func_a_stats["own_time_ns"] / total_program_time
131-
132-
# funcA importance should be approximately 0.57% (63000/10968000)
133-
assert abs(importance - 0.0057) < 0.001
131+
132+
# funcA importance should be approximately 1.9% (153000/7958000)
133+
assert abs(importance - 0.019) < 0.01
134134

135135

136-
def test_simple_model_predict_stats(function_ranker, workload_functions):
137-
# Find SimpleModel::predict function
138-
predict_func = None
139-
for func in workload_functions:
140-
if func.function_name == "predict":
141-
predict_func = func
142-
break
143-
144-
assert predict_func is not None
145-
146-
stats = function_ranker.get_function_stats_summary(predict_func)
147-
assert stats is not None
148-
assert stats["function_name"] == "predict"
149-
assert stats["call_count"] == 1
150-
assert stats["own_time_ns"] == 2289000
151-
assert stats["cumulative_time_ns"] == 4017000
152-
assert stats["addressable_time_ns"] == 4017000
153-
154-
# Test addressable time calculation
155-
addressable_time = function_ranker.get_function_addressable_time(predict_func)
156-
# Expected addressable time: own_time + (time_in_callees / call_count)
157-
# = 2289000 + ((4017000 - 2289000) / 1) = 4017000
158-
assert addressable_time == 4017000
159-
160-
# Test importance calculation for predict function
161-
total_program_time = sum(
162-
s["own_time_ns"] for s in function_ranker._function_stats.values()
163-
if s.get("own_time_ns", 0) > 0
164-
)
165-
importance = stats["own_time_ns"] / total_program_time
166-
# predict importance should be approximately 20.9% (2289000/10968000)
167-
assert abs(importance - 0.209) < 0.01

0 commit comments

Comments
 (0)