Skip to content

Commit c6fb474

Browse files
committed
Add mem2reg, LICM, and deduplicated optimizer rule files
Add dedicated mem2reg.egg and licm.egg rule files extracted from other rule files. Add disabled rule files (merge_return, loop_fusion, legalization, instrumentation) with explanatory comments. Expand datatypes.egg with legalization types. Add 15 unit tests for mem2reg and LICM passes.
1 parent 8895848 commit c6fb474

File tree

8 files changed

+1802
-0
lines changed

8 files changed

+1802
-0
lines changed

rust/spirv-tools-opt/src/egglog_opt.rs

Lines changed: 460 additions & 0 deletions
Large diffs are not rendered by default.

rust/spirv-tools-opt/src/rules/datatypes.egg

Lines changed: 410 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
; =============================================================================
2+
; Instrumentation Passes for E-Graph Optimization
3+
; =============================================================================
4+
;
5+
; Instrumentation adds profiling, debugging, and validation code to shaders.
6+
; In an e-graph, we represent both instrumented and uninstrumented forms.
7+
;
8+
; Key insight: Instrumentation rules ADD nodes rather than simplify.
9+
; During extraction, we can choose whether to include instrumentation
10+
; based on a build flag or debug mode.
11+
;
12+
; Instrumentation types:
13+
; - DebugPrintf: Print values for debugging
14+
; - ProfilingCounter: Count executions
15+
; - BoundsCheck: Validate array indices
16+
; - NaNCheck: Detect floating-point NaNs
17+
; - InfinityCheck: Detect floating-point infinities
18+
; - OverflowCheck: Detect integer overflow
19+
;
20+
; =============================================================================
21+
; SECTION 1: Debug Printf Instrumentation
22+
; =============================================================================
23+
24+
; Add debug print capability to any value
25+
; Instrumented(val) includes a debug output alongside the computation
26+
27+
; Scalar debug output
28+
(rule ((= e (Instrumented val)))
29+
((union e (Seq (DebugPrintf "%f" val) val))))
30+
31+
; Vector debug output
32+
(rule ((= e (InstrumentedVec2 val)))
33+
((union e (Seq (DebugPrintf "vec2(%f, %f)" (VecExtract val 0) (VecExtract val 1)) val))))
34+
(rule ((= e (InstrumentedVec3 val)))
35+
((union e (Seq (DebugPrintf "vec3(%f, %f, %f)"
36+
(VecExtract val 0) (VecExtract val 1) (VecExtract val 2)) val))))
37+
(rule ((= e (InstrumentedVec4 val)))
38+
((union e (Seq (DebugPrintf "vec4(%f, %f, %f, %f)"
39+
(VecExtract val 0) (VecExtract val 1) (VecExtract val 2) (VecExtract val 3)) val))))
40+
41+
; Named debug output
42+
(rule ((= e (InstrumentedNamed name val)))
43+
((union e (Seq (DebugPrintf name val) val))))
44+
45+
; =============================================================================
46+
; SECTION 2: Profiling Counters
47+
; =============================================================================
48+
49+
; Count how many times a code path is executed
50+
51+
; Increment profiling counter
52+
(rule ((= e (ProfiledBlock id body)))
53+
((union e (Seq (AtomicIAdd (ProfileCounter id) (Const 1)) body))))
54+
55+
; Profile loop iterations
56+
(rule ((= e (ProfiledLoop id (Theta cond body init))))
57+
((union e (Theta cond
58+
(Seq (AtomicIAdd (ProfileCounter id) (Const 1)) body)
59+
init))))
60+
61+
; Profile branch taken
62+
(rule ((= e (ProfiledBranch id (Gamma cond t f))))
63+
((union e (Gamma cond
64+
(Seq (AtomicIAdd (ProfileCounter id) (Const 1)) t)
65+
(Seq (AtomicIAdd (ProfileCounter (Add id 1)) (Const 1)) f)))))
66+
67+
; =============================================================================
68+
; SECTION 3: Bounds Checking
69+
; =============================================================================
70+
71+
; Add bounds checks to array accesses
72+
73+
; Check array index in bounds
74+
(rule ((= e (BoundsChecked (AccessChain1 arr idx) size)))
75+
((union e (Seq
76+
(Assert (ULt idx size) "Array index out of bounds")
77+
(AccessChain1 arr idx)))))
78+
79+
; Multi-dimensional bounds check
80+
(rule ((= e (BoundsChecked2D (AccessChain2 arr i j) rows cols)))
81+
((union e (Seq
82+
(Assert (LogAnd (ULt i rows) (ULt j cols)) "2D index out of bounds")
83+
(AccessChain2 arr i j)))))
84+
85+
; Buffer bounds check (byte offset)
86+
(rule ((= e (BufferBoundsChecked ptr offset size)))
87+
((union e (Seq
88+
(Assert (ULe (Add offset (Const 4)) size) "Buffer access out of bounds")
89+
(Load ptr)))))
90+
91+
; =============================================================================
92+
; SECTION 4: NaN/Infinity Detection
93+
; =============================================================================
94+
95+
; Detect NaN in floating-point computations
96+
97+
; Check for NaN after operation
98+
(rule ((= e (NaNChecked val)))
99+
((union e (Seq
100+
(Assert (LogNot (IsNan val)) "NaN detected")
101+
val))))
102+
103+
; Check for infinity
104+
(rule ((= e (InfChecked val)))
105+
((union e (Seq
106+
(Assert (LogNot (IsInf val)) "Infinity detected")
107+
val))))
108+
109+
; Check for both NaN and infinity
110+
(rule ((= e (FiniteChecked val)))
111+
((union e (Seq
112+
(Assert (LogAnd (LogNot (IsNan val)) (LogNot (IsInf val))) "Non-finite value")
113+
val))))
114+
115+
; Vector NaN check
116+
(rule ((= e (NaNCheckedVec val)))
117+
((union e (Seq
118+
(Assert (LogNot (Any (IsNan val))) "NaN in vector")
119+
val))))
120+
121+
; =============================================================================
122+
; SECTION 5: Integer Overflow Detection
123+
; =============================================================================
124+
125+
; Detect overflow in integer arithmetic
126+
127+
; Signed addition overflow
128+
(rule ((= e (OverflowCheckedAdd a b)))
129+
((union e (Seq
130+
(Assert (LogNot (AddOverflows a b)) "Integer overflow in addition")
131+
(Add a b)))))
132+
133+
; Signed multiplication overflow
134+
(rule ((= e (OverflowCheckedMul a b)))
135+
((union e (Seq
136+
(Assert (LogNot (MulOverflows a b)) "Integer overflow in multiplication")
137+
(Mul a b)))))
138+
139+
; Signed subtraction overflow
140+
(rule ((= e (OverflowCheckedSub a b)))
141+
((union e (Seq
142+
(Assert (LogNot (SubOverflows a b)) "Integer overflow in subtraction")
143+
(Sub a b)))))
144+
145+
; =============================================================================
146+
; SECTION 6: Division by Zero Detection
147+
; =============================================================================
148+
149+
; Check for division by zero
150+
151+
(rule ((= e (DivChecked a b)))
152+
((union e (Seq
153+
(Assert (Ne b (Const 0)) "Division by zero")
154+
(SDiv a b)))))
155+
156+
(rule ((= e (UDivChecked a b)))
157+
((union e (Seq
158+
(Assert (Ne b (Const 0)) "Division by zero")
159+
(UDiv a b)))))
160+
161+
(rule ((= e (FDivChecked a b)))
162+
((union e (Seq
163+
(Assert (FNe b (FConst 0.0)) "Division by zero")
164+
(FDiv a b)))))
165+
166+
; =============================================================================
167+
; SECTION 7: Null Pointer Checks
168+
; =============================================================================
169+
170+
; Check for null/invalid pointers
171+
172+
(rule ((= e (NullChecked ptr)))
173+
((union e (Seq
174+
(Assert (Ne ptr (NullPtr)) "Null pointer dereference")
175+
ptr))))
176+
177+
(rule ((= e (NullCheckedLoad ptr mem)))
178+
((union e (Seq
179+
(Assert (Ne ptr (NullPtr)) "Null pointer load")
180+
(Load ptr mem)))))
181+
182+
; =============================================================================
183+
; SECTION 8: Shader Validation
184+
; =============================================================================
185+
186+
; Validate shader-specific constraints
187+
188+
; Validate texture coordinates in range [0, 1]
189+
(rule ((= e (ValidatedTexCoord coord)))
190+
((union e (Seq
191+
(Assert (LogAnd (FGe coord (FConst 0.0)) (FLe coord (FConst 1.0)))
192+
"Texture coordinate out of range")
193+
coord))))
194+
195+
; Validate normal vector is normalized
196+
(rule ((= e (ValidatedNormal n)))
197+
((union e (Seq
198+
(Assert (FApproxEq (Length n) (FConst 1.0)) "Normal not normalized")
199+
n))))
200+
201+
; Validate matrix is orthogonal
202+
(rule ((= e (ValidatedOrthogonal m)))
203+
((union e (Seq
204+
(Assert (FApproxEq (Determinant m) (FConst 1.0)) "Matrix not orthogonal")
205+
m))))
206+
207+
; =============================================================================
208+
; SECTION 9: Timing Instrumentation
209+
; =============================================================================
210+
211+
; Measure execution time of code blocks
212+
213+
(rule ((= e (Timed id body)))
214+
((union e (Seq
215+
(AtomicIAdd (TimerStart id) (ReadClock))
216+
(Seq body
217+
(AtomicIAdd (TimerEnd id) (ReadClock)))))))
218+
219+
; =============================================================================
220+
; SECTION 10: Memory Access Pattern Tracking
221+
; =============================================================================
222+
223+
; Track memory access patterns for optimization hints
224+
225+
(rule ((= e (TrackedLoad ptr mem)))
226+
((union e (Seq
227+
(AtomicIAdd (AccessCounter (PtrToInt ptr)) (Const 1))
228+
(Load ptr mem)))))
229+
230+
(rule ((= e (TrackedStore ptr val mem)))
231+
((union e (Seq
232+
(AtomicIAdd (StoreCounter (PtrToInt ptr)) (Const 1))
233+
(StoreMem ptr val mem)))))
234+
235+
; =============================================================================
236+
; SECTION 11: Stripping Instrumentation
237+
; =============================================================================
238+
239+
; Rules to remove instrumentation for release builds
240+
; These are the "inverse" rules - stripping out the debug code
241+
242+
; Strip debug prints
243+
(rule ((= e (Seq (DebugPrintf _ _) val)))
244+
((union e val)))
245+
246+
; Strip profiling
247+
(rule ((= e (Seq (AtomicIAdd (ProfileCounter _) _) body)))
248+
((union e body)))
249+
250+
; Strip assertions (in release mode)
251+
(rule ((= e (Seq (Assert _ _) val)))
252+
((union e val)))
253+
254+
; Strip timing
255+
(rule ((= e (Seq (AtomicIAdd (TimerStart _) _) (Seq body (AtomicIAdd (TimerEnd _) _)))))
256+
((union e body)))
257+
258+
; =============================================================================
259+
; SECTION 12: Conditional Instrumentation
260+
; =============================================================================
261+
262+
; Enable instrumentation only under certain conditions
263+
264+
; Debug mode flag
265+
(rule ((= e (DebugMode (Instrumented val))))
266+
((union e (Gamma (DebugEnabled) (Instrumented val) val))))
267+
268+
; Verbose mode for extra output
269+
(rule ((= e (VerboseMode body)))
270+
((union e (Gamma (VerboseEnabled) body (Pure)))))
271+
272+
; Sample-based profiling (only instrument some invocations)
273+
(rule ((= e (SampledProfile rate body)))
274+
((union e (Gamma (Eq (Mod (InvocationId) rate) (Const 0))
275+
(ProfiledBlock body)
276+
body))))

0 commit comments

Comments
 (0)