Skip to content

Commit 176f866

Browse files
committed
Fix Derivation Across Multiple Passes
Variable nodes should have origin nodes so when expressions are simplified in multiple passes, previous simplifications are preserved
1 parent 8db7a47 commit 176f866

File tree

3 files changed

+105
-46
lines changed

3 files changed

+105
-46
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1111
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
1212

13+
import java.util.HashMap;
1314
import java.util.Map;
1415

1516
public class ConstantPropagation {
@@ -19,64 +20,122 @@ public class ConstantPropagation {
1920
* VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
2021
* the propagation steps taken.
2122
*/
22-
public static ValDerivationNode propagate(Expression exp, DerivationNode defaultOrigin) {
23+
public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) {
2324
Map<String, Expression> substitutions = VariableResolver.resolve(exp);
24-
return propagateRecursive(exp, substitutions, defaultOrigin);
25+
26+
// map of variable origins from the previous derivation tree
27+
Map<String, DerivationNode> varOrigins = new HashMap<>();
28+
if (previousOrigin != null) {
29+
extractVarOrigins(previousOrigin, varOrigins);
30+
}
31+
return propagateRecursive(exp, substitutions, varOrigins);
2532
}
2633

2734
/**
2835
* Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
2936
*/
3037
private static ValDerivationNode propagateRecursive(Expression exp, Map<String, Expression> subs,
31-
DerivationNode defaultOrigin) {
38+
Map<String, DerivationNode> varOrigins) {
3239

3340
// substitute variable
3441
if (exp instanceof Var var) {
3542
String name = var.getName();
3643
Expression value = subs.get(name);
3744
// substitution
38-
if (value != null)
39-
return new ValDerivationNode(value.clone(), new VarDerivationNode(name));
45+
if (value != null) {
46+
// check if this variable has an origin from a previous pass
47+
DerivationNode previousOrigin = varOrigins.get(name);
48+
49+
// preserve origin if value came from previous derivation
50+
DerivationNode origin = previousOrigin != null ? new VarDerivationNode(name, previousOrigin) : new VarDerivationNode(name);
51+
return new ValDerivationNode(value.clone(), origin);
52+
}
4053

4154
// no substitution
42-
return new ValDerivationNode(var, defaultOrigin);
55+
return new ValDerivationNode(var, null);
4356
}
4457

4558
// lift unary origin
4659
if (exp instanceof UnaryExpression unary) {
47-
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs, defaultOrigin);
60+
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs, varOrigins);
4861
UnaryExpression cloned = (UnaryExpression) unary.clone();
4962
cloned.setChild(0, operand.getValue());
5063

51-
DerivationNode origin = operand.getOrigin() != null ? new UnaryDerivationNode(operand, cloned.getOp())
52-
: defaultOrigin;
53-
return new ValDerivationNode(cloned, origin);
64+
return operand.getOrigin() != null
65+
? new ValDerivationNode(cloned, new UnaryDerivationNode(operand, cloned.getOp()))
66+
: new ValDerivationNode(cloned, null);
5467
}
5568

5669
// lift binary origin
5770
if (exp instanceof BinaryExpression binary) {
58-
ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs, defaultOrigin);
59-
ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs, defaultOrigin);
71+
ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs, varOrigins);
72+
ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs, varOrigins);
6073
BinaryExpression cloned = (BinaryExpression) binary.clone();
6174
cloned.setChild(0, left.getValue());
6275
cloned.setChild(1, right.getValue());
6376

64-
DerivationNode origin = (left.getOrigin() != null || right.getOrigin() != null)
65-
? new BinaryDerivationNode(left, right, cloned.getOperator()) : defaultOrigin;
66-
return new ValDerivationNode(cloned, origin);
77+
return (left.getOrigin() != null || right.getOrigin() != null)
78+
? new ValDerivationNode(cloned, new BinaryDerivationNode(left, right, cloned.getOperator()))
79+
: new ValDerivationNode(cloned, null);
6780
}
6881

6982
// recursively propagate children
7083
if (exp.hasChildren()) {
7184
Expression propagated = exp.clone();
7285
for (int i = 0; i < exp.getChildren().size(); i++) {
73-
ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs, defaultOrigin);
86+
ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs, varOrigins);
7487
propagated.setChild(i, child.getValue());
7588
}
76-
return new ValDerivationNode(propagated, defaultOrigin);
89+
return new ValDerivationNode(propagated, null);
7790
}
7891

7992
// no propagation
80-
return new ValDerivationNode(exp, defaultOrigin);
93+
return new ValDerivationNode(exp, null);
94+
}
95+
96+
97+
/**
98+
* Extracts the derivation nodes for variable values from the derivation tree
99+
* This is so done so when we find "var == value" in the tree, we store the derivation of the value
100+
* So it can be preserved when var is substituted in subsequent passes
101+
*/
102+
private static void extractVarOrigins(ValDerivationNode node, Map<String, DerivationNode> varOrigins) {
103+
if (node == null)
104+
return;
105+
106+
Expression value = node.getValue();
107+
DerivationNode origin = node.getOrigin();
108+
109+
// check for equality expressions
110+
if (value instanceof BinaryExpression binExp && "==".equals(binExp.getOperator())
111+
&& origin instanceof BinaryDerivationNode binOrigin) {
112+
Expression left = binExp.getFirstOperand();
113+
Expression right = binExp.getSecondOperand();
114+
115+
// extract variable name and value derivation from either side
116+
String varName = null;
117+
ValDerivationNode valueDerivation = null;
118+
119+
if (left instanceof Var var && right.isLiteral()) {
120+
varName = var.getName();
121+
valueDerivation = binOrigin.getRight();
122+
} else if (right instanceof Var var && left.isLiteral()) {
123+
varName = var.getName();
124+
valueDerivation = binOrigin.getLeft();
125+
}
126+
if (varName != null && valueDerivation != null && valueDerivation.getOrigin() != null) {
127+
varOrigins.put(varName, valueDerivation.getOrigin());
128+
}
129+
}
130+
131+
// recursively process the origin tree
132+
if (origin instanceof BinaryDerivationNode binOrigin) {
133+
extractVarOrigins(binOrigin.getLeft(), varOrigins);
134+
extractVarOrigins(binOrigin.getRight(), varOrigins);
135+
} else if (origin instanceof UnaryDerivationNode unaryOrigin) {
136+
extractVarOrigins(unaryOrigin.getOperand(), varOrigins);
137+
} else if (origin instanceof ValDerivationNode valOrigin) {
138+
extractVarOrigins(valOrigin, varOrigins);
139+
}
81140
}
82141
}

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,23 @@
33
public class VarDerivationNode extends DerivationNode {
44

55
private final String var;
6+
private final DerivationNode origin;
67

78
public VarDerivationNode(String var) {
89
this.var = var;
10+
this.origin = null;
11+
}
12+
13+
public VarDerivationNode(String var, DerivationNode origin) {
14+
this.var = var;
15+
this.origin = origin;
916
}
1017

1118
public String getVar() {
1219
return var;
1320
}
21+
22+
public DerivationNode getOrigin() {
23+
return origin;
24+
}
1425
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -335,35 +335,24 @@ void testFixedPointSimplification() {
335335

336336
// Compare derivation tree structure
337337

338-
// Origin of y (value 2) - right operand of result
339-
ValDerivationNode originY = new ValDerivationNode(new LiteralInt(2), new VarDerivationNode("y"));
340-
UnaryDerivationNode originNeg2 = new UnaryDerivationNode(originY, "-");
341-
ValDerivationNode rightNode = new ValDerivationNode(new LiteralInt(-2), originNeg2);
338+
// Build the derivation chain for the right side:
339+
// 6 came from a, 3 came from b
340+
ValDerivationNode val6FromA = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a"));
341+
ValDerivationNode val3FromB = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("b"));
342342

343-
// Origin of x - left operand of result
344-
// 6 (from a) / 3 (from b) -> 2
345-
ValDerivationNode val6 = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a"));
346-
ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("b"));
347-
BinaryDerivationNode divOp = new BinaryDerivationNode(val6, val3, "/");
348-
ValDerivationNode val2FromDiv = new ValDerivationNode(new LiteralInt(2), divOp);
349-
350-
// y == 2 (from y == 6 / 3)
351-
ValDerivationNode valYNode = new ValDerivationNode(new Var("y"), null);
352-
BinaryDerivationNode eqYOp = new BinaryDerivationNode(valYNode, val2FromDiv, "==");
353-
ValDerivationNode yEq2 = new ValDerivationNode(new BinaryExpression(new Var("y"), "==", new LiteralInt(2)),
354-
eqYOp);
355-
356-
// x == -y
357-
ValDerivationNode xEqNegY = new ValDerivationNode(
358-
new BinaryExpression(new Var("x"), "==", new UnaryExpression("-", new Var("y"))), null);
359-
360-
// x == -y && y == 2
361-
BinaryDerivationNode andOp1 = new BinaryDerivationNode(xEqNegY, yEq2, "&&");
362-
ValDerivationNode xEqNegYAndYEq2 = new ValDerivationNode(
363-
new BinaryExpression(xEqNegY.getValue(), "&&", yEq2.getValue()), andOp1);
364-
365-
// Left node x has origin pointing to the previous simplification's tree
366-
ValDerivationNode leftNode = new ValDerivationNode(new Var("x"), xEqNegYAndYEq2);
343+
// 6 / 3 -> 2
344+
BinaryDerivationNode divOrigin = new BinaryDerivationNode(val6FromA, val3FromB, "/");
345+
346+
// 2 came from y, and y's value came from 6 / 2
347+
VarDerivationNode yChainedOrigin = new VarDerivationNode("y", divOrigin);
348+
ValDerivationNode val2FromY = new ValDerivationNode(new LiteralInt(2), yChainedOrigin);
349+
350+
// -2
351+
UnaryDerivationNode negOrigin = new UnaryDerivationNode(val2FromY, "-");
352+
ValDerivationNode rightNode = new ValDerivationNode(new LiteralInt(-2), negOrigin);
353+
354+
// Left node x has no origin
355+
ValDerivationNode leftNode = new ValDerivationNode(new Var("x"), null);
367356

368357
// Root equality
369358
BinaryDerivationNode rootOrigin = new BinaryDerivationNode(leftNode, rightNode, "==");

0 commit comments

Comments
 (0)