Skip to content

Commit 5a48590

Browse files
committed
optimized dispatch
1 parent 75dd08c commit 5a48590

5 files changed

Lines changed: 593 additions & 3 deletions

File tree

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/WLogger.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public static void trace(Supplier<String> msgSupplier) {
4040
}
4141

4242
public static void info(String msg) {
43+
System.out.println(msg);
4344
instance.info(msg);
4445
}
4546

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
package de.peeeq.wurstscript.intermediatelang.optimizer;
2+
3+
import de.peeeq.wurstscript.WurstOperator;
4+
import de.peeeq.wurstscript.jassIm.*;
5+
import de.peeeq.wurstscript.translation.imoptimizer.OptimizerPass;
6+
import de.peeeq.wurstscript.translation.imtranslation.ImTranslator;
7+
8+
import java.util.ArrayList;
9+
import java.util.IdentityHashMap;
10+
import java.util.Objects;
11+
import java.util.Set;
12+
13+
/**
14+
* Collapses consecutive identical dispatch safety checks into one.
15+
*
16+
* This targets the repetitive pattern emitted by checked dispatch after
17+
* inlining, where several method calls on the same receiver produce adjacent
18+
* copies of the same guard.
19+
*/
20+
public class DispatchCheckDeduplicator implements OptimizerPass {
21+
22+
private int rewrites;
23+
private final IdentityHashMap<ImFunction, IdentityHashMap<ImVar, Boolean>> mayWriteTypeIdMemo = new IdentityHashMap<>();
24+
25+
@Override
26+
public int optimize(ImTranslator trans) {
27+
rewrites = 0;
28+
mayWriteTypeIdMemo.clear();
29+
ImProg prog = trans.getImProg();
30+
for (ImFunction f : prog.getFunctions()) {
31+
optimizeStmts(f.getBody());
32+
}
33+
prog.flatten(trans);
34+
return rewrites;
35+
}
36+
37+
@Override
38+
public String getName() {
39+
return "Dispatch Check Dedup";
40+
}
41+
42+
private void optimizeStmts(ImStmts stmts) {
43+
for (ImStmt s : new ArrayList<>(stmts)) {
44+
if (s instanceof ImIf) {
45+
ImIf imIf = (ImIf) s;
46+
optimizeStmts(imIf.getThenBlock());
47+
optimizeStmts(imIf.getElseBlock());
48+
} else if (s instanceof ImLoop) {
49+
optimizeStmts(((ImLoop) s).getBody());
50+
} else if (s instanceof ImVarargLoop) {
51+
optimizeStmts(((ImVarargLoop) s).getBody());
52+
}
53+
}
54+
55+
int i = 0;
56+
while (i < stmts.size()) {
57+
GuardPattern first = extractDispatchGuard(stmts.get(i));
58+
if (first == null) {
59+
i++;
60+
continue;
61+
}
62+
63+
int j = i + 1;
64+
while (j < stmts.size()) {
65+
ImStmt s = stmts.get(j);
66+
GuardPattern next = extractDispatchGuard(s);
67+
if (next != null) {
68+
if (first.sameGuardAs(next)) {
69+
stmts.remove(j);
70+
rewrites++;
71+
continue;
72+
}
73+
// Different guard: keep scanning only if statement cannot invalidate this guard.
74+
if (invalidatesGuard(s, first)) {
75+
break;
76+
}
77+
j++;
78+
continue;
79+
}
80+
if (invalidatesGuard(s, first)) {
81+
break;
82+
}
83+
j++;
84+
}
85+
86+
i++;
87+
}
88+
}
89+
90+
private boolean invalidatesGuard(ImStmt s, GuardPattern guard) {
91+
if (s instanceof ImSet) {
92+
ImSet set = (ImSet) s;
93+
ImLExpr left = set.getLeft();
94+
if (left instanceof ImVarAccess) {
95+
ImVar v = ((ImVarAccess) left).getVar();
96+
return v == guard.failedCond.receiverVar || v == guard.failedCond.typeIdVar;
97+
}
98+
if (left instanceof ImVarArrayAccess) {
99+
ImVar v = ((ImVarArrayAccess) left).getVar();
100+
return v == guard.failedCond.typeIdVar;
101+
}
102+
if (left instanceof ImMemberAccess) {
103+
ImVar v = ((ImMemberAccess) left).getVar();
104+
return v == guard.failedCond.typeIdVar;
105+
}
106+
return false;
107+
}
108+
if (s instanceof ImFunctionCall) {
109+
ImFunction f = ((ImFunctionCall) s).getFunc();
110+
if (isKnownNonMutatingFunction(f)) {
111+
return false;
112+
}
113+
return mayWriteTypeId(f, guard.failedCond.typeIdVar);
114+
}
115+
if (s instanceof ImMethodCall) {
116+
ImMethod m = ((ImMethodCall) s).getMethod();
117+
return mayWriteTypeId(m.getImplementation(), guard.failedCond.typeIdVar);
118+
}
119+
if (s instanceof ImDealloc || s instanceof ImAlloc) {
120+
return true;
121+
}
122+
if (s instanceof ImIf || s instanceof ImLoop || s instanceof ImVarargLoop
123+
|| s instanceof ImReturn || s instanceof ImExitwhen) {
124+
return true;
125+
}
126+
return false;
127+
}
128+
129+
private boolean mayWriteTypeId(ImFunction f, ImVar typeIdVar) {
130+
if (f == null) {
131+
return true;
132+
}
133+
IdentityHashMap<ImVar, Boolean> byTypeId = mayWriteTypeIdMemo.computeIfAbsent(f, k -> new IdentityHashMap<>());
134+
Boolean memo = byTypeId.get(typeIdVar);
135+
if (memo != null) {
136+
return memo;
137+
}
138+
boolean result = mayWriteTypeIdImpl(f, typeIdVar, java.util.Collections.newSetFromMap(new IdentityHashMap<>()));
139+
byTypeId.put(typeIdVar, result);
140+
return result;
141+
}
142+
143+
private boolean isKnownNonMutatingFunction(ImFunction f) {
144+
if (f == null) return false;
145+
String n = f.getName();
146+
return "println".equals(n)
147+
|| "print".equals(n)
148+
|| "I2S".equals(n)
149+
|| "R2S".equals(n)
150+
|| "BJDebugMsg".equals(n);
151+
}
152+
153+
private boolean mayWriteTypeIdImpl(ImFunction f, ImVar typeIdVar, Set<ImFunction> visiting) {
154+
if (f.isNative()) {
155+
return true;
156+
}
157+
if (!visiting.add(f)) {
158+
return true;
159+
}
160+
final boolean[] writes = {false};
161+
f.accept(new ImFunction.DefaultVisitor() {
162+
@Override
163+
public void visit(ImSet e) {
164+
super.visit(e);
165+
ImLExpr left = e.getLeft();
166+
if (left instanceof ImVarArrayAccess && ((ImVarArrayAccess) left).getVar() == typeIdVar) {
167+
writes[0] = true;
168+
} else if (left instanceof ImVarAccess && ((ImVarAccess) left).getVar() == typeIdVar) {
169+
writes[0] = true;
170+
} else if (left instanceof ImMemberAccess && ((ImMemberAccess) left).getVar() == typeIdVar) {
171+
writes[0] = true;
172+
}
173+
}
174+
175+
@Override
176+
public void visit(ImFunctionCall e) {
177+
super.visit(e);
178+
if (!writes[0] && mayWriteTypeIdImpl(e.getFunc(), typeIdVar, visiting)) {
179+
writes[0] = true;
180+
}
181+
}
182+
183+
@Override
184+
public void visit(ImMethodCall e) {
185+
super.visit(e);
186+
if (!writes[0] && mayWriteTypeIdImpl(e.getMethod().getImplementation(), typeIdVar, visiting)) {
187+
writes[0] = true;
188+
}
189+
}
190+
191+
@Override
192+
public void visit(ImDealloc e) {
193+
super.visit(e);
194+
writes[0] = true;
195+
}
196+
});
197+
visiting.remove(f);
198+
return writes[0];
199+
}
200+
201+
private GuardPattern extractDispatchGuard(ImStmt stmt) {
202+
if (!(stmt instanceof ImIf)) {
203+
return null;
204+
}
205+
ImIf outer = (ImIf) stmt;
206+
if (!outer.getElseBlock().isEmpty() || outer.getThenBlock().size() != 1) {
207+
return null;
208+
}
209+
GuardCond failed = parseTypeIdZeroCond(outer.getCondition());
210+
if (failed == null) {
211+
return null;
212+
}
213+
214+
ImStmt innerStmt = outer.getThenBlock().get(0);
215+
if (!(innerStmt instanceof ImIf)) {
216+
return null;
217+
}
218+
ImIf inner = (ImIf) innerStmt;
219+
if (inner.getThenBlock().size() != 1 || inner.getElseBlock().size() != 1) {
220+
return null;
221+
}
222+
if (!isReceiverZeroCond(inner.getCondition(), failed.receiverVar)) {
223+
return null;
224+
}
225+
226+
ErrorCall nullErr = parseSingleErrorCall(inner.getThenBlock().get(0));
227+
ErrorCall invalidErr = parseSingleErrorCall(inner.getElseBlock().get(0));
228+
if (nullErr == null || invalidErr == null) {
229+
return null;
230+
}
231+
232+
return new GuardPattern(failed, nullErr, invalidErr);
233+
}
234+
235+
private static GuardCond parseTypeIdZeroCond(ImExpr expr) {
236+
if (!(expr instanceof ImOperatorCall)) {
237+
return null;
238+
}
239+
ImOperatorCall op = (ImOperatorCall) expr;
240+
if (op.getOp() != WurstOperator.EQ || op.getArguments().size() != 2) {
241+
return null;
242+
}
243+
ImExpr a = op.getArguments().get(0);
244+
ImExpr b = op.getArguments().get(1);
245+
GuardCond c = parseTypeIdEqZero(a, b);
246+
if (c != null) {
247+
return c;
248+
}
249+
return parseTypeIdEqZero(b, a);
250+
}
251+
252+
private static GuardCond parseTypeIdEqZero(ImExpr left, ImExpr right) {
253+
if (!(right instanceof ImIntVal) || ((ImIntVal) right).getValI() != 0) {
254+
return null;
255+
}
256+
if (!(left instanceof ImVarArrayAccess)) {
257+
return null;
258+
}
259+
ImVarArrayAccess aa = (ImVarArrayAccess) left;
260+
if (aa.getIndexes().size() != 1 || !(aa.getIndexes().get(0) instanceof ImVarAccess)) {
261+
return null;
262+
}
263+
ImVar receiver = ((ImVarAccess) aa.getIndexes().get(0)).getVar();
264+
return new GuardCond(aa.getVar(), receiver);
265+
}
266+
267+
private static boolean isReceiverZeroCond(ImExpr expr, ImVar receiver) {
268+
if (!(expr instanceof ImOperatorCall)) {
269+
return false;
270+
}
271+
ImOperatorCall op = (ImOperatorCall) expr;
272+
if (op.getOp() != WurstOperator.EQ || op.getArguments().size() != 2) {
273+
return false;
274+
}
275+
return isReceiverEqZero(op.getArguments().get(0), op.getArguments().get(1), receiver)
276+
|| isReceiverEqZero(op.getArguments().get(1), op.getArguments().get(0), receiver);
277+
}
278+
279+
private static boolean isReceiverEqZero(ImExpr left, ImExpr right, ImVar receiver) {
280+
return left instanceof ImVarAccess
281+
&& ((ImVarAccess) left).getVar() == receiver
282+
&& right instanceof ImIntVal
283+
&& ((ImIntVal) right).getValI() == 0;
284+
}
285+
286+
private static ErrorCall parseSingleErrorCall(ImStmt stmt) {
287+
if (!(stmt instanceof ImFunctionCall)) {
288+
return null;
289+
}
290+
ImFunctionCall fc = (ImFunctionCall) stmt;
291+
if (fc.getArguments().size() != 1 || !(fc.getArguments().get(0) instanceof ImStringVal)) {
292+
return null;
293+
}
294+
return new ErrorCall(fc.getFunc(), ((ImStringVal) fc.getArguments().get(0)).getValS());
295+
}
296+
297+
private static final class GuardPattern {
298+
private final GuardCond failedCond;
299+
private final ErrorCall nullError;
300+
private final ErrorCall invalidError;
301+
302+
private GuardPattern(GuardCond failedCond, ErrorCall nullError, ErrorCall invalidError) {
303+
this.failedCond = failedCond;
304+
this.nullError = nullError;
305+
this.invalidError = invalidError;
306+
}
307+
308+
private boolean sameGuardAs(GuardPattern other) {
309+
return failedCond.typeIdVar == other.failedCond.typeIdVar
310+
&& failedCond.receiverVar == other.failedCond.receiverVar
311+
&& nullError.sameAs(other.nullError)
312+
&& invalidError.sameAs(other.invalidError);
313+
}
314+
}
315+
316+
private static final class GuardCond {
317+
private final ImVar typeIdVar;
318+
private final ImVar receiverVar;
319+
320+
private GuardCond(ImVar typeIdVar, ImVar receiverVar) {
321+
this.typeIdVar = typeIdVar;
322+
this.receiverVar = receiverVar;
323+
}
324+
}
325+
326+
private static final class ErrorCall {
327+
private final ImFunction func;
328+
private final String message;
329+
330+
private ErrorCall(ImFunction func, String message) {
331+
this.func = func;
332+
this.message = message;
333+
}
334+
335+
private boolean sameAs(ErrorCall other) {
336+
return func == other.func && Objects.equals(message, other.message);
337+
}
338+
}
339+
}

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImOptimizer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import de.peeeq.wurstscript.WLogger;
77
import de.peeeq.wurstscript.intermediatelang.optimizer.BranchMerger;
88
import de.peeeq.wurstscript.intermediatelang.optimizer.ConstantAndCopyPropagation;
9+
import de.peeeq.wurstscript.intermediatelang.optimizer.DispatchCheckDeduplicator;
910
import de.peeeq.wurstscript.intermediatelang.optimizer.LocalMerger;
1011
import de.peeeq.wurstscript.intermediatelang.optimizer.SideEffectAnalyzer;
1112
import de.peeeq.wurstscript.intermediatelang.optimizer.SimpleRewrites;
@@ -32,6 +33,7 @@ public class ImOptimizer {
3233
localPasses.add(new ConstantAndCopyPropagation());
3334
localPasses.add(new UselessFunctionCallsRemover());
3435
localPasses.add(new GlobalsInliner());
36+
localPasses.add(new DispatchCheckDeduplicator());
3537
localPasses.add(new SimpleRewrites());
3638
}
3739

0 commit comments

Comments
 (0)