1+ import ast
2+
3+
4+ def get_const_val (node ):
5+ if isinstance (node , ast .Constant ):
6+ return node .value
7+ if getattr (ast , "Num" , None ) and isinstance (node , ast .Num ):
8+ return node .n
9+ if isinstance (node , ast .UnaryOp ) and isinstance (node .op , ast .USub ):
10+ val = get_const_val (node .operand )
11+ if val is not None :
12+ return - val
13+ return None
14+
15+
16+ class SafeAlgebraPass (ast .NodeTransformer ):
17+ """
18+ Safely folds constants and eliminates redundant arithmetic operations.
19+ Strictly preserves vector semantics for expressions containing 'arange'
20+ to prevent broadcasting bugs (e.g., scalarizing zero-strided tensors).
21+ """
22+
23+ def visit_BinOp (self , node ):
24+ node = self .generic_visit (node )
25+ lv = get_const_val (node .left )
26+ rv = get_const_val (node .right )
27+
28+ # Core protection: Never optimize vector expressions containing 'arange' into a pure scalar 0.
29+ has_arange = "arange" in ast .dump (node )
30+
31+ # 1. Basic constant folding
32+ if lv is not None and rv is not None :
33+ try :
34+ if isinstance (node .op , ast .Add ):
35+ return ast .Constant (value = lv + rv )
36+ if isinstance (node .op , ast .Sub ):
37+ return ast .Constant (value = lv - rv )
38+ if isinstance (node .op , ast .Mult ):
39+ return ast .Constant (value = lv * rv )
40+ if isinstance (node .op , ast .FloorDiv ) and rv != 0 :
41+ return ast .Constant (value = lv // rv )
42+ except Exception :
43+ pass
44+
45+ # 2. Redundant addition/subtraction elimination (e.g., (X - C) + C -> X)
46+ if isinstance (node .op , ast .Add ) and isinstance (node .left , ast .BinOp ) and isinstance (node .left .op , ast .Sub ):
47+ if get_const_val (node .right ) == get_const_val (node .left .right ) and get_const_val (node .right ) is not None :
48+ return node .left .left
49+ if isinstance (node .op , ast .Sub ) and isinstance (node .left , ast .BinOp ) and isinstance (node .left .op , ast .Add ):
50+ if get_const_val (node .right ) == get_const_val (node .left .right ) and get_const_val (node .right ) is not None :
51+ return node .left .left
52+
53+ # 3. Multiplication by 0/1 and Addition/Subtraction by 0
54+ if isinstance (node .op , ast .Mult ):
55+ if lv == 0 or rv == 0 :
56+ if not has_arange :
57+ return ast .Constant (value = 0 ) # Prevents scalarization
58+ if lv == 1 :
59+ return node .right
60+ if rv == 1 :
61+ return node .left
62+ elif isinstance (node .op , ast .Add ):
63+ if lv == 0 :
64+ return node .right
65+ if rv == 0 :
66+ return node .left
67+ elif isinstance (node .op , ast .Sub ):
68+ if rv == 0 :
69+ return node .left
70+ if ast .dump (node .left ) == ast .dump (node .right ):
71+ if not has_arange :
72+ return ast .Constant (value = 0 ) # Prevents scalarization
73+ elif isinstance (node .op , ast .FloorDiv ):
74+ if rv == 1 :
75+ return node .left
76+ if lv == 0 :
77+ if not has_arange :
78+ return ast .Constant (value = 0 ) # Prevents scalarization
79+
80+ return node
81+
82+
83+ class UltimateBCEPass (ast .NodeTransformer ):
84+ """
85+ Bounds Checking Elimination (BCE).
86+ Eliminates redundant bounds checks by statically evaluating trivially true conditions.
87+ """
88+
89+ def __init__ (self ):
90+ self .loop_bounds = {}
91+
92+ def visit_For (self , node ):
93+ if isinstance (node .target , ast .Name ):
94+ self .loop_bounds [node .target .id ] = True
95+ self .generic_visit (node )
96+ if isinstance (node .target , ast .Name ):
97+ self .loop_bounds .pop (node .target .id , None )
98+ return node
99+
100+ def visit_Compare (self , node ):
101+ node = self .generic_visit (node )
102+ if len (node .ops ) != 1 :
103+ return node
104+
105+ left = node .left
106+ op = node .ops [0 ]
107+ right = node .comparators [0 ]
108+ l_dump = ast .dump (left )
109+ r_dump = ast .dump (right )
110+ lv = get_const_val (left )
111+ rv = get_const_val (right )
112+
113+ if lv is not None and rv is not None :
114+ try :
115+ if isinstance (op , ast .Lt ):
116+ return ast .Constant (value = lv < rv )
117+ if isinstance (op , ast .GtE ):
118+ return ast .Constant (value = lv >= rv )
119+ except Exception :
120+ pass
121+
122+ is_loop_var = isinstance (left , ast .Name ) and left .id in self .loop_bounds
123+ if is_loop_var and isinstance (op , ast .GtE ) and rv == 0 :
124+ return ast .Constant (value = True )
125+ if is_loop_var and isinstance (op , ast .Lt ) and "FloorDiv" in r_dump and "size" in r_dump :
126+ return ast .Constant (value = True )
127+
128+ if isinstance (op , ast .GtE ) and rv == 0 :
129+ if "arange" in l_dump and "index" not in l_dump and "pid" not in l_dump :
130+ return ast .Constant (value = True )
131+ if isinstance (op , ast .Lt ) and "arange" in l_dump and "BLOCK_SIZE" in l_dump and "BLOCK_SIZE" in r_dump :
132+ if "size" not in r_dump :
133+ return ast .Constant (value = True )
134+
135+ return node
136+
137+
138+ class MaskCSEPass (ast .NodeTransformer ):
139+ """
140+ Mask-level Common Subexpression Elimination (CSE).
141+ Simplifies mask expressions by flattening nested BitAnd operations and removing duplicate conditions.
142+ """
143+
144+ def visit_BinOp (self , node ):
145+ node = self .generic_visit (node )
146+ if isinstance (node .op , ast .BitAnd ):
147+
148+ def flatten_and (n ):
149+ if isinstance (n , ast .BinOp ) and isinstance (n .op , ast .BitAnd ):
150+ return flatten_and (n .left ) + flatten_and (n .right )
151+ return [n ]
152+
153+ terms = flatten_and (node )
154+ unique_terms = []
155+ seen = set ()
156+ for t in terms :
157+ if getattr (t , "value" , None ) is True :
158+ continue
159+ d = ast .dump (t )
160+ if d not in seen :
161+ seen .add (d )
162+ unique_terms .append (t )
163+
164+ if not unique_terms :
165+ return ast .Constant (value = True )
166+
167+ res = unique_terms [0 ]
168+ for t in unique_terms [1 :]:
169+ res = ast .BinOp (left = res , op = ast .BitAnd (), right = t )
170+ return res
171+
172+ return node
173+
174+
175+ def apply_advanced_optimizations (tree ):
176+ ast .fix_missing_locations (tree )
177+ old_dump = ""
178+ # Iteratively apply optimizations until a fixed point is reached
179+ while old_dump != ast .dump (tree ):
180+ old_dump = ast .dump (tree )
181+ tree = SafeAlgebraPass ().visit (tree )
182+ tree = UltimateBCEPass ().visit (tree )
183+ tree = MaskCSEPass ().visit (tree )
184+ ast .fix_missing_locations (tree )
185+
186+ return tree
0 commit comments