Skip to content

Commit 2c3603c

Browse files
committed
Add advanced AST passes for SDPA and MM performance optimization
1 parent b4441ee commit 2c3603c

2 files changed

Lines changed: 190 additions & 0 deletions

File tree

src/ninetoothed/advanced_passes.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import ast
2+
3+
4+
def get_const_val(node):
5+
if isinstance(node, ast.Constant):
6+
return node.value
7+
if getattr(ast, "Num", None) and isinstance(node, ast.Num):
8+
return node.n
9+
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
10+
val = get_const_val(node.operand)
11+
if val is not None:
12+
return -val
13+
return None
14+
15+
16+
class SafeAlgebraPass(ast.NodeTransformer):
17+
"""
18+
Safely folds constants and eliminates redundant arithmetic operations.
19+
Strictly preserves vector semantics for expressions containing 'arange'
20+
to prevent broadcasting bugs (e.g., scalarizing zero-strided tensors).
21+
"""
22+
23+
def visit_BinOp(self, node):
24+
node = self.generic_visit(node)
25+
lv = get_const_val(node.left)
26+
rv = get_const_val(node.right)
27+
28+
# Core protection: Never optimize vector expressions containing 'arange' into a pure scalar 0.
29+
has_arange = "arange" in ast.dump(node)
30+
31+
# 1. Basic constant folding
32+
if lv is not None and rv is not None:
33+
try:
34+
if isinstance(node.op, ast.Add):
35+
return ast.Constant(value=lv + rv)
36+
if isinstance(node.op, ast.Sub):
37+
return ast.Constant(value=lv - rv)
38+
if isinstance(node.op, ast.Mult):
39+
return ast.Constant(value=lv * rv)
40+
if isinstance(node.op, ast.FloorDiv) and rv != 0:
41+
return ast.Constant(value=lv // rv)
42+
except Exception:
43+
pass
44+
45+
# 2. Redundant addition/subtraction elimination (e.g., (X - C) + C -> X)
46+
if isinstance(node.op, ast.Add) and isinstance(node.left, ast.BinOp) and isinstance(node.left.op, ast.Sub):
47+
if get_const_val(node.right) == get_const_val(node.left.right) and get_const_val(node.right) is not None:
48+
return node.left.left
49+
if isinstance(node.op, ast.Sub) and isinstance(node.left, ast.BinOp) and isinstance(node.left.op, ast.Add):
50+
if get_const_val(node.right) == get_const_val(node.left.right) and get_const_val(node.right) is not None:
51+
return node.left.left
52+
53+
# 3. Multiplication by 0/1 and Addition/Subtraction by 0
54+
if isinstance(node.op, ast.Mult):
55+
if lv == 0 or rv == 0:
56+
if not has_arange:
57+
return ast.Constant(value=0) # Prevents scalarization
58+
if lv == 1:
59+
return node.right
60+
if rv == 1:
61+
return node.left
62+
elif isinstance(node.op, ast.Add):
63+
if lv == 0:
64+
return node.right
65+
if rv == 0:
66+
return node.left
67+
elif isinstance(node.op, ast.Sub):
68+
if rv == 0:
69+
return node.left
70+
if ast.dump(node.left) == ast.dump(node.right):
71+
if not has_arange:
72+
return ast.Constant(value=0) # Prevents scalarization
73+
elif isinstance(node.op, ast.FloorDiv):
74+
if rv == 1:
75+
return node.left
76+
if lv == 0:
77+
if not has_arange:
78+
return ast.Constant(value=0) # Prevents scalarization
79+
80+
return node
81+
82+
83+
class UltimateBCEPass(ast.NodeTransformer):
84+
"""
85+
Bounds Checking Elimination (BCE).
86+
Eliminates redundant bounds checks by statically evaluating trivially true conditions.
87+
"""
88+
89+
def __init__(self):
90+
self.loop_bounds = {}
91+
92+
def visit_For(self, node):
93+
if isinstance(node.target, ast.Name):
94+
self.loop_bounds[node.target.id] = True
95+
self.generic_visit(node)
96+
if isinstance(node.target, ast.Name):
97+
self.loop_bounds.pop(node.target.id, None)
98+
return node
99+
100+
def visit_Compare(self, node):
101+
node = self.generic_visit(node)
102+
if len(node.ops) != 1:
103+
return node
104+
105+
left = node.left
106+
op = node.ops[0]
107+
right = node.comparators[0]
108+
l_dump = ast.dump(left)
109+
r_dump = ast.dump(right)
110+
lv = get_const_val(left)
111+
rv = get_const_val(right)
112+
113+
if lv is not None and rv is not None:
114+
try:
115+
if isinstance(op, ast.Lt):
116+
return ast.Constant(value=lv < rv)
117+
if isinstance(op, ast.GtE):
118+
return ast.Constant(value=lv >= rv)
119+
except Exception:
120+
pass
121+
122+
is_loop_var = isinstance(left, ast.Name) and left.id in self.loop_bounds
123+
if is_loop_var and isinstance(op, ast.GtE) and rv == 0:
124+
return ast.Constant(value=True)
125+
if is_loop_var and isinstance(op, ast.Lt) and "FloorDiv" in r_dump and "size" in r_dump:
126+
return ast.Constant(value=True)
127+
128+
if isinstance(op, ast.GtE) and rv == 0:
129+
if "arange" in l_dump and "index" not in l_dump and "pid" not in l_dump:
130+
return ast.Constant(value=True)
131+
if isinstance(op, ast.Lt) and "arange" in l_dump and "BLOCK_SIZE" in l_dump and "BLOCK_SIZE" in r_dump:
132+
if "size" not in r_dump:
133+
return ast.Constant(value=True)
134+
135+
return node
136+
137+
138+
class MaskCSEPass(ast.NodeTransformer):
139+
"""
140+
Mask-level Common Subexpression Elimination (CSE).
141+
Simplifies mask expressions by flattening nested BitAnd operations and removing duplicate conditions.
142+
"""
143+
144+
def visit_BinOp(self, node):
145+
node = self.generic_visit(node)
146+
if isinstance(node.op, ast.BitAnd):
147+
148+
def flatten_and(n):
149+
if isinstance(n, ast.BinOp) and isinstance(n.op, ast.BitAnd):
150+
return flatten_and(n.left) + flatten_and(n.right)
151+
return [n]
152+
153+
terms = flatten_and(node)
154+
unique_terms = []
155+
seen = set()
156+
for t in terms:
157+
if getattr(t, "value", None) is True:
158+
continue
159+
d = ast.dump(t)
160+
if d not in seen:
161+
seen.add(d)
162+
unique_terms.append(t)
163+
164+
if not unique_terms:
165+
return ast.Constant(value=True)
166+
167+
res = unique_terms[0]
168+
for t in unique_terms[1:]:
169+
res = ast.BinOp(left=res, op=ast.BitAnd(), right=t)
170+
return res
171+
172+
return node
173+
174+
175+
def apply_advanced_optimizations(tree):
176+
ast.fix_missing_locations(tree)
177+
old_dump = ""
178+
# Iteratively apply optimizations until a fixed point is reached
179+
while old_dump != ast.dump(tree):
180+
old_dump = ast.dump(tree)
181+
tree = SafeAlgebraPass().visit(tree)
182+
tree = UltimateBCEPass().visit(tree)
183+
tree = MaskCSEPass().visit(tree)
184+
ast.fix_missing_locations(tree)
185+
186+
return tree

src/ninetoothed/generation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ def _find_dependencies(func):
103103
self.visit(tree)
104104
Tritonizer().visit(tree)
105105
_BinOpSimplifier().visit(tree)
106+
107+
from .advanced_passes import apply_advanced_optimizations
108+
tree = apply_advanced_optimizations(tree)
109+
106110
ast.fix_missing_locations(tree)
107111

108112
if prettify:

0 commit comments

Comments
 (0)