1010import liquidjava .rj_language .opt .derivation_node .ValDerivationNode ;
1111import liquidjava .rj_language .opt .derivation_node .VarDerivationNode ;
1212
13+ import java .util .HashMap ;
1314import java .util .Map ;
1415
1516public class ConstantPropagation {
@@ -19,64 +20,122 @@ public class ConstantPropagation {
1920 * VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
2021 * the propagation steps taken.
2122 */
22- public static ValDerivationNode propagate (Expression exp , DerivationNode defaultOrigin ) {
23+ public static ValDerivationNode propagate (Expression exp , ValDerivationNode previousOrigin ) {
2324 Map <String , Expression > substitutions = VariableResolver .resolve (exp );
24- return propagateRecursive (exp , substitutions , defaultOrigin );
25+
26+ // map of variable origins from the previous derivation tree
27+ Map <String , DerivationNode > varOrigins = new HashMap <>();
28+ if (previousOrigin != null ) {
29+ extractVarOrigins (previousOrigin , varOrigins );
30+ }
31+ return propagateRecursive (exp , substitutions , varOrigins );
2532 }
2633
2734 /**
2835 * Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
2936 */
3037 private static ValDerivationNode propagateRecursive (Expression exp , Map <String , Expression > subs ,
31- DerivationNode defaultOrigin ) {
38+ Map < String , DerivationNode > varOrigins ) {
3239
3340 // substitute variable
3441 if (exp instanceof Var var ) {
3542 String name = var .getName ();
3643 Expression value = subs .get (name );
3744 // substitution
38- if (value != null )
39- return new ValDerivationNode (value .clone (), new VarDerivationNode (name ));
45+ if (value != null ) {
46+ // check if this variable has an origin from a previous pass
47+ DerivationNode previousOrigin = varOrigins .get (name );
48+
49+ // preserve origin if value came from previous derivation
50+ DerivationNode origin = previousOrigin != null ? new VarDerivationNode (name , previousOrigin ) : new VarDerivationNode (name );
51+ return new ValDerivationNode (value .clone (), origin );
52+ }
4053
4154 // no substitution
42- return new ValDerivationNode (var , defaultOrigin );
55+ return new ValDerivationNode (var , null );
4356 }
4457
4558 // lift unary origin
4659 if (exp instanceof UnaryExpression unary ) {
47- ValDerivationNode operand = propagateRecursive (unary .getChildren ().get (0 ), subs , defaultOrigin );
60+ ValDerivationNode operand = propagateRecursive (unary .getChildren ().get (0 ), subs , varOrigins );
4861 UnaryExpression cloned = (UnaryExpression ) unary .clone ();
4962 cloned .setChild (0 , operand .getValue ());
5063
51- DerivationNode origin = operand .getOrigin () != null ? new UnaryDerivationNode ( operand , cloned . getOp ())
52- : defaultOrigin ;
53- return new ValDerivationNode (cloned , origin );
64+ return operand .getOrigin () != null
65+ ? new ValDerivationNode ( cloned , new UnaryDerivationNode ( operand , cloned . getOp ()))
66+ : new ValDerivationNode (cloned , null );
5467 }
5568
5669 // lift binary origin
5770 if (exp instanceof BinaryExpression binary ) {
58- ValDerivationNode left = propagateRecursive (binary .getFirstOperand (), subs , defaultOrigin );
59- ValDerivationNode right = propagateRecursive (binary .getSecondOperand (), subs , defaultOrigin );
71+ ValDerivationNode left = propagateRecursive (binary .getFirstOperand (), subs , varOrigins );
72+ ValDerivationNode right = propagateRecursive (binary .getSecondOperand (), subs , varOrigins );
6073 BinaryExpression cloned = (BinaryExpression ) binary .clone ();
6174 cloned .setChild (0 , left .getValue ());
6275 cloned .setChild (1 , right .getValue ());
6376
64- DerivationNode origin = (left .getOrigin () != null || right .getOrigin () != null )
65- ? new BinaryDerivationNode (left , right , cloned .getOperator ()) : defaultOrigin ;
66- return new ValDerivationNode (cloned , origin );
77+ return (left .getOrigin () != null || right .getOrigin () != null )
78+ ? new ValDerivationNode ( cloned , new BinaryDerivationNode (left , right , cloned .getOperator ()))
79+ : new ValDerivationNode (cloned , null );
6780 }
6881
6982 // recursively propagate children
7083 if (exp .hasChildren ()) {
7184 Expression propagated = exp .clone ();
7285 for (int i = 0 ; i < exp .getChildren ().size (); i ++) {
73- ValDerivationNode child = propagateRecursive (exp .getChildren ().get (i ), subs , defaultOrigin );
86+ ValDerivationNode child = propagateRecursive (exp .getChildren ().get (i ), subs , varOrigins );
7487 propagated .setChild (i , child .getValue ());
7588 }
76- return new ValDerivationNode (propagated , defaultOrigin );
89+ return new ValDerivationNode (propagated , null );
7790 }
7891
7992 // no propagation
80- return new ValDerivationNode (exp , defaultOrigin );
93+ return new ValDerivationNode (exp , null );
94+ }
95+
96+
97+ /**
98+ * Extracts the derivation nodes for variable values from the derivation tree
99+ * This is so done so when we find "var == value" in the tree, we store the derivation of the value
100+ * So it can be preserved when var is substituted in subsequent passes
101+ */
102+ private static void extractVarOrigins (ValDerivationNode node , Map <String , DerivationNode > varOrigins ) {
103+ if (node == null )
104+ return ;
105+
106+ Expression value = node .getValue ();
107+ DerivationNode origin = node .getOrigin ();
108+
109+ // check for equality expressions
110+ if (value instanceof BinaryExpression binExp && "==" .equals (binExp .getOperator ())
111+ && origin instanceof BinaryDerivationNode binOrigin ) {
112+ Expression left = binExp .getFirstOperand ();
113+ Expression right = binExp .getSecondOperand ();
114+
115+ // extract variable name and value derivation from either side
116+ String varName = null ;
117+ ValDerivationNode valueDerivation = null ;
118+
119+ if (left instanceof Var var && right .isLiteral ()) {
120+ varName = var .getName ();
121+ valueDerivation = binOrigin .getRight ();
122+ } else if (right instanceof Var var && left .isLiteral ()) {
123+ varName = var .getName ();
124+ valueDerivation = binOrigin .getLeft ();
125+ }
126+ if (varName != null && valueDerivation != null && valueDerivation .getOrigin () != null ) {
127+ varOrigins .put (varName , valueDerivation .getOrigin ());
128+ }
129+ }
130+
131+ // recursively process the origin tree
132+ if (origin instanceof BinaryDerivationNode binOrigin ) {
133+ extractVarOrigins (binOrigin .getLeft (), varOrigins );
134+ extractVarOrigins (binOrigin .getRight (), varOrigins );
135+ } else if (origin instanceof UnaryDerivationNode unaryOrigin ) {
136+ extractVarOrigins (unaryOrigin .getOperand (), varOrigins );
137+ } else if (origin instanceof ValDerivationNode valOrigin ) {
138+ extractVarOrigins (valOrigin , varOrigins );
139+ }
81140 }
82141}
0 commit comments