Skip to content

Commit 6963a61

Browse files
authored
Expression Oversimplification (#179)
1 parent ed68be2 commit 6963a61

File tree

2 files changed

+229
-23
lines changed

2 files changed

+229
-23
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import liquidjava.rj_language.ast.BinaryExpression;
44
import liquidjava.rj_language.ast.Expression;
55
import liquidjava.rj_language.ast.LiteralBoolean;
6+
import liquidjava.rj_language.ast.UnaryExpression;
67
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
78
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
9+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
810
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
911

1012
public class ExpressionSimplifier {
@@ -15,12 +17,13 @@ public class ExpressionSimplifier {
1517
*/
1618
public static ValDerivationNode simplify(Expression exp) {
1719
ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp);
18-
return simplifyValDerivationNode(fixedPoint);
20+
ValDerivationNode simplified = simplifyValDerivationNode(fixedPoint);
21+
return unwrapBooleanLiterals(simplified);
1922
}
2023

2124
/**
2225
* Recursively applies propagation and folding until the expression stops changing (fixed point) Stops early if the
23-
* expression simplifies to 'true', which means we've simplified too much
26+
* expression simplifies to a boolean literal, which means we've simplified too much
2427
*/
2528
private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp) {
2629
// apply propagation and folding
@@ -34,6 +37,11 @@ private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current,
3437
return current;
3538
}
3639

40+
// prevent oversimplification
41+
if (current != null && currExp instanceof LiteralBoolean && !(current.getValue() instanceof LiteralBoolean)) {
42+
return current;
43+
}
44+
3745
// continue simplifying
3846
return simplifyToFixedPoint(simplified, simplified.getValue());
3947
}
@@ -114,4 +122,61 @@ private static boolean isRedundant(Expression exp) {
114122
}
115123
return false;
116124
}
125+
126+
/**
127+
* Recursively traverses the derivation tree and replaces boolean literals with the expressions that produced them,
128+
* but only when at least one operand in the derivation is non-boolean. e.g. "x == true" where true came from "1 >
129+
* 0" becomes "x == 1 > 0"
130+
*/
131+
private static ValDerivationNode unwrapBooleanLiterals(ValDerivationNode node) {
132+
Expression value = node.getValue();
133+
DerivationNode origin = node.getOrigin();
134+
135+
if (origin == null)
136+
return node;
137+
138+
// unwrap binary expressions
139+
if (value instanceof BinaryExpression binExp && origin instanceof BinaryDerivationNode binOrigin) {
140+
ValDerivationNode left = unwrapBooleanLiterals(binOrigin.getLeft());
141+
ValDerivationNode right = unwrapBooleanLiterals(binOrigin.getRight());
142+
if (left != binOrigin.getLeft() || right != binOrigin.getRight()) {
143+
Expression newValue = new BinaryExpression(left.getValue(), binExp.getOperator(), right.getValue());
144+
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
145+
}
146+
return node;
147+
}
148+
149+
// unwrap unary expressions
150+
if (value instanceof UnaryExpression unaryExp && origin instanceof UnaryDerivationNode unaryOrigin) {
151+
ValDerivationNode operand = unwrapBooleanLiterals(unaryOrigin.getOperand());
152+
if (operand != unaryOrigin.getOperand()) {
153+
Expression newValue = new UnaryExpression(unaryExp.getOp(), operand.getValue());
154+
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
155+
}
156+
return node;
157+
}
158+
159+
// boolean literal with binary origin: unwrap if at least one child is non-boolean
160+
if (value instanceof LiteralBoolean && origin instanceof BinaryDerivationNode binOrigin) {
161+
ValDerivationNode left = unwrapBooleanLiterals(binOrigin.getLeft());
162+
ValDerivationNode right = unwrapBooleanLiterals(binOrigin.getRight());
163+
if (!(left.getValue() instanceof LiteralBoolean) || !(right.getValue() instanceof LiteralBoolean)) {
164+
Expression newValue = new BinaryExpression(left.getValue(), binOrigin.getOp(), right.getValue());
165+
return new ValDerivationNode(newValue, new BinaryDerivationNode(left, right, binOrigin.getOp()));
166+
}
167+
return node;
168+
}
169+
170+
// boolean literal with unary origin: unwrap if operand is non-boolean
171+
if (value instanceof LiteralBoolean && origin instanceof UnaryDerivationNode unaryOrigin) {
172+
ValDerivationNode operand = unwrapBooleanLiterals(unaryOrigin.getOperand());
173+
if (!(operand.getValue() instanceof LiteralBoolean)) {
174+
Expression newValue = new UnaryExpression(unaryOrigin.getOp(), operand.getValue());
175+
return new ValDerivationNode(newValue, new UnaryDerivationNode(operand, unaryOrigin.getOp()));
176+
}
177+
return node;
178+
}
179+
180+
return node;
181+
}
117182
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 162 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,13 @@ void testComplexArithmeticWithMultipleOperations() {
243243
// When
244244
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
245245

246-
// Then
246+
// Then: boolean literals are unwrapped to show the verified conditions
247247
assertNotNull(result, "Result should not be null");
248248
assertNotNull(result.getValue(), "Result value should not be null");
249-
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should be a boolean literal");
250-
assertTrue(result.getValue().isBooleanTrue(), "Expected result to be true");
249+
assertEquals("14 == 14 && 5 == 5 && 7 == 7 && 14 == 14", result.getValue().toString(),
250+
"All verified conditions should be visible instead of collapsed to true");
251251

252-
// 5 * 2 + 7 - 3
252+
// 5 * 2 + 7 - 3 = 14
253253
ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
254254
ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null);
255255
BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*");
@@ -266,39 +266,45 @@ void testComplexArithmeticWithMultipleOperations() {
266266
// 14 from variable c
267267
ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
268268

269-
// 14 == 14
269+
// 14 == 14 (unwrapped from true)
270270
BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "==");
271-
ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14);
271+
Expression expr14Eq14 = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
272+
ValDerivationNode compare14Node = new ValDerivationNode(expr14Eq14, compare14);
272273

273-
// a == 5 => true
274+
// a == 5 => 5 == 5 (unwrapped from true)
274275
ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
275276
ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null);
276277
BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "==");
277-
ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5);
278+
Expression expr5Eq5 = new BinaryExpression(new LiteralInt(5), "==", new LiteralInt(5));
279+
ValDerivationNode compare5Node = new ValDerivationNode(expr5Eq5, compareA5);
278280

279-
// b == 7 => true
281+
// b == 7 => 7 == 7 (unwrapped from true)
280282
ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
281283
ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null);
282284
BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "==");
283-
ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7);
285+
Expression expr7Eq7 = new BinaryExpression(new LiteralInt(7), "==", new LiteralInt(7));
286+
ValDerivationNode compare7Node = new ValDerivationNode(expr7Eq7, compareB7);
284287

285-
// (a == 5) && (b == 7) => true
286-
BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&");
287-
ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB);
288+
// (5 == 5) && (7 == 7) (unwrapped from true)
289+
BinaryDerivationNode andAB = new BinaryDerivationNode(compare5Node, compare7Node, "&&");
290+
Expression expr5And7 = new BinaryExpression(expr5Eq5, "&&", expr7Eq7);
291+
ValDerivationNode and5And7Node = new ValDerivationNode(expr5And7, andAB);
288292

289-
// c == 14 => true
293+
// c == 14 => 14 == 14 (unwrapped from true)
290294
ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
291295
ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null);
292296
BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "==");
293-
ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14);
297+
Expression expr14Eq14b = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14));
298+
ValDerivationNode compare14bNode = new ValDerivationNode(expr14Eq14b, compareC14);
294299

295-
// ((a == 5) && (b == 7)) && (c == 14) => true
296-
BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&");
297-
ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC);
300+
// ((5 == 5) && (7 == 7)) && (14 == 14) (unwrapped from true)
301+
BinaryDerivationNode andABC = new BinaryDerivationNode(and5And7Node, compare14bNode, "&&");
302+
Expression exprConditions = new BinaryExpression(expr5And7, "&&", expr14Eq14b);
303+
ValDerivationNode conditionsNode = new ValDerivationNode(exprConditions, andABC);
298304

299-
// 14 == 14 => true
300-
BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&");
301-
ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd);
305+
// (14 == 14) && ((5 == 5 && 7 == 7) && 14 == 14)
306+
BinaryDerivationNode finalAnd = new BinaryDerivationNode(compare14Node, conditionsNode, "&&");
307+
ValDerivationNode expected = new ValDerivationNode(result.getValue(), finalAnd);
302308

303309
// Compare the derivation trees
304310
assertDerivationEquals(expected, result, "");
@@ -550,6 +556,141 @@ void testTransitive() {
550556
assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1");
551557
}
552558

559+
@Test
560+
void testShouldNotOversimplifyToTrue() {
561+
// Given: x > 5 && x == y && y == 10
562+
// Iteration 1: resolves y == 10, substitutes y -> 10: x > 5 && x == 10
563+
// Iteration 2: resolves x == 10, substitutes x -> 10: 10 > 5 && 10 == 10 -> true
564+
// Expected: x > 5 && x == 10 (should NOT simplify to true)
565+
566+
Expression varX = new Var("x");
567+
Expression varY = new Var("y");
568+
Expression five = new LiteralInt(5);
569+
Expression ten = new LiteralInt(10);
570+
571+
Expression xGreater5 = new BinaryExpression(varX, ">", five);
572+
Expression xEqualsY = new BinaryExpression(varX, "==", varY);
573+
Expression yEquals10 = new BinaryExpression(varY, "==", ten);
574+
575+
Expression firstAnd = new BinaryExpression(xGreater5, "&&", xEqualsY);
576+
Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEquals10);
577+
578+
// When
579+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
580+
581+
// Then
582+
assertNotNull(result, "Result should not be null");
583+
assertFalse(result.getValue() instanceof LiteralBoolean,
584+
"Should not oversimplify to a boolean literal, but got: " + result.getValue());
585+
assertEquals("x > 5 && x == 10", result.getValue().toString(),
586+
"Should stop simplification before collapsing to true");
587+
}
588+
589+
@Test
590+
void testShouldUnwrapBooleanInEquality() {
591+
// Given: x == (1 > 0)
592+
// Without unwrapping: x == true (unhelpful - hides what "true" came from)
593+
// Expected: x == 1 > 0 (unwrapped to show the original comparison)
594+
595+
Expression varX = new Var("x");
596+
Expression one = new LiteralInt(1);
597+
Expression zero = new LiteralInt(0);
598+
Expression oneGreaterZero = new BinaryExpression(one, ">", zero);
599+
Expression fullExpression = new BinaryExpression(varX, "==", oneGreaterZero);
600+
601+
// When
602+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
603+
604+
// Then
605+
assertNotNull(result, "Result should not be null");
606+
assertEquals("x == 1 > 0", result.getValue().toString(),
607+
"Boolean in equality should be unwrapped to show the original comparison");
608+
}
609+
610+
@Test
611+
void testShouldUnwrapBooleanInEqualityWithPropagation() {
612+
// Given: x == (a > b) && a == 3 && b == 1
613+
// Without unwrapping: x == true (unhelpful)
614+
// Expected: x == 3 > 1 (unwrapped and propagated)
615+
616+
Expression varX = new Var("x");
617+
Expression varA = new Var("a");
618+
Expression varB = new Var("b");
619+
Expression aGreaterB = new BinaryExpression(varA, ">", varB);
620+
Expression xEqualsComp = new BinaryExpression(varX, "==", aGreaterB);
621+
622+
Expression three = new LiteralInt(3);
623+
Expression aEquals3 = new BinaryExpression(varA, "==", three);
624+
Expression one = new LiteralInt(1);
625+
Expression bEquals1 = new BinaryExpression(varB, "==", one);
626+
627+
Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals1);
628+
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);
629+
630+
// When
631+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
632+
633+
// Then
634+
assertNotNull(result, "Result should not be null");
635+
assertEquals("x == 3 > 1", result.getValue().toString(),
636+
"Boolean in equality should be unwrapped after propagation");
637+
}
638+
639+
@Test
640+
void testShouldNotUnwrapBooleanWithBooleanChildren() {
641+
// Given: (y || true) && !true && y == false
642+
// Expected: false (both children of the fold are boolean, so no unwrapping needed)
643+
644+
Expression varY = new Var("y");
645+
Expression trueExp = new LiteralBoolean(true);
646+
Expression yOrTrue = new BinaryExpression(varY, "||", trueExp);
647+
Expression notTrue = new UnaryExpression("!", trueExp);
648+
Expression falseExp = new LiteralBoolean(false);
649+
Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp);
650+
651+
Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue);
652+
Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse);
653+
654+
// When
655+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
656+
657+
// Then: false stays as false since both sides in the derivation are booleans
658+
assertNotNull(result, "Result should not be null");
659+
assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should remain a boolean");
660+
assertFalse(result.getValue().isBooleanTrue(), "Expected result to be false");
661+
}
662+
663+
@Test
664+
void testShouldUnwrapNestedBooleanInEquality() {
665+
// Given: x == (a + b > 10) && a == 3 && b == 5
666+
// Without unwrapping: x == true (unhelpful)
667+
// Expected: x == 8 > 10 (shows the actual comparison that produced the boolean)
668+
669+
Expression varX = new Var("x");
670+
Expression varA = new Var("a");
671+
Expression varB = new Var("b");
672+
Expression aPlusB = new BinaryExpression(varA, "+", varB);
673+
Expression ten = new LiteralInt(10);
674+
Expression comparison = new BinaryExpression(aPlusB, ">", ten);
675+
Expression xEqualsComp = new BinaryExpression(varX, "==", comparison);
676+
677+
Expression three = new LiteralInt(3);
678+
Expression aEquals3 = new BinaryExpression(varA, "==", three);
679+
Expression five = new LiteralInt(5);
680+
Expression bEquals5 = new BinaryExpression(varB, "==", five);
681+
682+
Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5);
683+
Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions);
684+
685+
// When
686+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
687+
688+
// Then
689+
assertNotNull(result, "Result should not be null");
690+
assertEquals("x == 8 > 10", result.getValue().toString(),
691+
"Boolean in equality should be unwrapped to show the computed comparison");
692+
}
693+
553694
/**
554695
* Helper method to compare two derivation nodes recursively
555696
*/

0 commit comments

Comments
 (0)