Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.google.common.collect.Sets;
import de.peeeq.wurstscript.jassIm.*;
import de.peeeq.wurstscript.translation.imtranslation.*;
import de.peeeq.wurstscript.types.TypesHelper;

import java.util.*;

Expand Down Expand Up @@ -104,7 +105,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
if (called == f) {
throw new Error("cannot inline self.");
}
List<ImStmt> stmts = Lists.newArrayList();
List<ImStmt> prefixStmts = Lists.newArrayList();
// save arguments to temp vars:
List<ImExpr> args = call.getArguments().removeAll();
Map<ImVar, ImVar> varSubtitutions = Maps.newLinkedHashMap();
Expand All @@ -115,7 +116,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
f.getLocals().add(tempVar);
varSubtitutions.put(param, tempVar);
// set temp var
stmts.add(JassIm.ImSet(arg.attrTrace(), JassIm.ImVarAccess(tempVar), arg));
prefixStmts.add(JassIm.ImSet(arg.attrTrace(), JassIm.ImVarAccess(tempVar), arg));
}
// add locals
for (ImVar l : called.getLocals()) {
Expand All @@ -124,6 +125,7 @@ private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCal
varSubtitutions.put(l, newL);
}
// add body and replace params with tempvars
List<ImStmt> copiedBody = Lists.newArrayList();
for (int i = 0; i < called.getBody().size(); i++) {
ImStmt s = called.getBody().get(i).copy();
ImHelper.replaceVar(s, varSubtitutions);
Expand All @@ -138,22 +140,48 @@ public void visit(ImFunctionCall called) {
});


stmts.add(s);
copiedBody.add(s);
}
// handle return

List<ImStmt> stmts = Lists.newArrayList();
stmts.addAll(prefixStmts);

ImExpr newExpr = null;
if (stmts.size() > 0) {
ImStmt lastStmt = stmts.get(stmts.size() - 1);
if (lastStmt instanceof ImReturn) {
ImReturn ret = (ImReturn) lastStmt;
stmts.remove(stmts.size() - 1);
ImExprOpt valOpt = ret.getReturnValue();
if (valOpt instanceof ImExpr) {
ImExpr val = (ImExpr) valOpt.copy();
ImHelper.replaceVar(val, varSubtitutions);
newExpr = ImStatementExpr(ImStmts(stmts), val);
if (maxOneReturn(called)) {
// Fast path for existing single-return shape.
stmts.addAll(copiedBody);
if (!stmts.isEmpty()) {
ImStmt lastStmt = stmts.get(stmts.size() - 1);
if (lastStmt instanceof ImReturn) {
ImReturn ret = (ImReturn) lastStmt;
stmts.remove(stmts.size() - 1);
ImExprOpt valOpt = ret.getReturnValue();
if (valOpt instanceof ImExpr) {
ImExpr val = (ImExpr) valOpt.copy();
ImHelper.replaceVar(val, varSubtitutions);
newExpr = ImStatementExpr(ImStmts(stmts), val);
}
}
}
} else {
// Multi-return path: rewrite returns to done-flag + optional return temp.
ImVar doneVar = JassIm.ImVar(call.attrTrace(), TypesHelper.imBool(), "inlineDone", false);
f.getLocals().add(doneVar);
stmts.add(JassIm.ImSet(call.attrTrace(), JassIm.ImVarAccess(doneVar), JassIm.ImBoolVal(false)));

ImVar retVar = null;
if (!(called.getReturnType() instanceof ImVoid)) {
retVar = JassIm.ImVar(call.attrTrace(), called.getReturnType().copy(), "inlineRet", false);
f.getLocals().add(retVar);
stmts.add(JassIm.ImSet(call.attrTrace(), JassIm.ImVarAccess(retVar), ImHelper.defaultValueForComplexType(called.getReturnType())));
}

ImStmts rewritten = rewriteForEarlyReturns(JassIm.ImStmts(copiedBody), doneVar, retVar);
stmts.addAll(rewritten.removeAll());

if (retVar != null) {
newExpr = ImStatementExpr(ImStmts(stmts), JassIm.ImVarAccess(retVar));
}
}
if (newExpr == null) {
newExpr = ImHelper.statementExprVoid(ImStmts(stmts));
Expand All @@ -162,6 +190,48 @@ public void visit(ImFunctionCall called) {

}

private ImStmts rewriteForEarlyReturns(ImStmts body, ImVar doneVar, ImVar retVar) {
ImStmts rewritten = JassIm.ImStmts();
for (ImStmt s : body) {
ImStmt transformed = rewriteStmtForEarlyReturn(s, doneVar, retVar);
ImExpr notDone = JassIm.ImOperatorCall(de.peeeq.wurstscript.WurstOperator.NOT, JassIm.ImExprs(JassIm.ImVarAccess(doneVar)));
rewritten.add(JassIm.ImIf(s.attrTrace(), notDone, JassIm.ImStmts(transformed), JassIm.ImStmts()));
}
return rewritten;
}

private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar) {
if (s instanceof ImReturn) {
ImReturn r = (ImReturn) s;
ImStmts b = JassIm.ImStmts();
if (retVar != null && r.getReturnValue() instanceof ImExpr) {
ImExpr rv = (ImExpr) r.getReturnValue();
rv.setParent(null);
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(retVar), rv));
}
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(doneVar), JassIm.ImBoolVal(true)));
return ImHelper.statementExprVoid(b);
} else if (s instanceof ImIf) {
ImIf imIf = (ImIf) s;
ImStmts thenBlock = rewriteForEarlyReturns(imIf.getThenBlock().copy(), doneVar, retVar);
ImStmts elseBlock = rewriteForEarlyReturns(imIf.getElseBlock().copy(), doneVar, retVar);
return JassIm.ImIf(imIf.getTrace(), imIf.getCondition().copy(), thenBlock, elseBlock);
} else if (s instanceof ImLoop) {
ImLoop l = (ImLoop) s;
ImStmts loopBody = JassIm.ImStmts();
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
return JassIm.ImLoop(l.getTrace(), loopBody);
} else if (s instanceof ImVarargLoop) {
ImVarargLoop l = (ImVarargLoop) s;
ImStmts loopBody = JassIm.ImStmts();
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
return JassIm.ImVarargLoop(l.getTrace(), loopBody, l.getLoopVar());
}
return s;
}

private void rateInlinableFunctions() {
for (Map.Entry<ImFunction, ImFunction> f : translator.getCalledFunctions().entries()) {
incCallCount(f.getKey());
Expand Down Expand Up @@ -288,9 +358,7 @@ private void collectInlinableFunctions() {
// this is only relevant for lua, because in JASS they are eliminated before inlining
continue;
}
if (maxOneReturn(f)) {
inlinableFunctions.add(f);
}
inlinableFunctions.add(f);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,28 @@ public void testInlineAnnotation() throws IOException {
assertTrue(inlined.contains("function noot"));
}

@Test
public void inlinerSupportsMultiReturn() throws IOException {
testAssertOkLines(true,
"package test",
"native testSuccess()",
"function absLike(int x) returns int",
" if x >= 0",
" return x",
" return 0 - x",
"init",
" let a = absLike(-4)",
" let b = absLike(3)",
" if a == 4 and b == 3",
" testSuccess()",
"endpackage"
);

String inlined = Files.toString(new File("test-output/OptimizerTests_inlinerSupportsMultiReturn_inl.j"), Charsets.UTF_8);
assertFalse(inlined.contains("call absLike"),
Comment thread
Frotty marked this conversation as resolved.
"Expected multi-return function calls to be inlined in _inl output.");
}


@Test
public void moveTowardsBug() { // see #737
Expand Down
Loading