@@ -354,9 +354,11 @@ impl Expr {
354354/// Predicates
355355impl Expr {
356356 /// Replace all occurrences of `Symbol { identifier: old_id }` with `replacement`.
357- /// Produces a new expression tree with all substitutions applied.
358- /// Used for quantifier codegen to inline function bodies as pure expressions.
359- pub fn substitute_symbol ( self , old_id : & InternedString , replacement : & Expr ) -> Expr {
357+ /// Returns `(new_expr, changed)` where `changed` indicates if any substitution occurred.
358+ ///
359+ /// Note: Does NOT recurse into `StatementExpression` nodes. These must be
360+ /// flattened first via `inline_as_pure_expr` before substitution.
361+ pub fn substitute_symbol ( self , old_id : & InternedString , replacement : & Expr ) -> ( Expr , bool ) {
360362 let loc = self . location ;
361363 let typ = self . typ . clone ( ) ;
362364 let ann = self . size_of_annotation . clone ( ) ;
@@ -367,44 +369,119 @@ impl Expr {
367369 size_of_annotation : ann. clone ( ) ,
368370 } ;
369371 let sub = |e : Expr | e. substitute_symbol ( old_id, replacement) ;
370- let sub_vec = |v : Vec < Expr > | v. into_iter ( ) . map ( & sub) . collect ( ) ;
372+ let sub_vec = |v : Vec < Expr > | -> ( Vec < Expr > , bool ) {
373+ let mut changed = false ;
374+ let v: Vec < _ > = v
375+ . into_iter ( )
376+ . map ( |e| {
377+ let ( e, c) = sub ( e) ;
378+ changed |= c;
379+ e
380+ } )
381+ . collect ( ) ;
382+ ( v, changed)
383+ } ;
371384
372385 match * self . value {
373386 ExprValue :: Symbol { identifier } if identifier == * old_id => {
374- replacement. clone ( ) . with_location ( loc)
387+ ( replacement. clone ( ) . with_location ( loc) , true )
388+ }
389+ ExprValue :: AddressOf ( e) => {
390+ let ( e, c) = sub ( e) ;
391+ ( mk ( AddressOf ( e) ) , c)
392+ }
393+ ExprValue :: Dereference ( e) => {
394+ let ( e, c) = sub ( e) ;
395+ ( mk ( Dereference ( e) ) , c)
396+ }
397+ ExprValue :: Typecast ( e) => {
398+ let ( e, c) = sub ( e) ;
399+ ( mk ( Typecast ( e) ) , c)
400+ }
401+ ExprValue :: UnOp { op, e } => {
402+ let ( e, c) = sub ( e) ;
403+ ( mk ( UnOp { op, e } ) , c)
404+ }
405+ ExprValue :: BinOp { op, lhs, rhs } => {
406+ let ( l, c1) = sub ( lhs) ;
407+ let ( r, c2) = sub ( rhs) ;
408+ ( mk ( BinOp { op, lhs : l, rhs : r } ) , c1 || c2)
409+ }
410+ ExprValue :: If { c, t, e } => {
411+ let ( c, c1) = sub ( c) ;
412+ let ( t, c2) = sub ( t) ;
413+ let ( e, c3) = sub ( e) ;
414+ ( mk ( If { c, t, e } ) , c1 || c2 || c3)
415+ }
416+ ExprValue :: Index { array, index } => {
417+ let ( a, c1) = sub ( array) ;
418+ let ( i, c2) = sub ( index) ;
419+ ( mk ( Index { array : a, index : i } ) , c1 || c2)
420+ }
421+ ExprValue :: Member { lhs, field } => {
422+ let ( l, c) = sub ( lhs) ;
423+ ( mk ( Member { lhs : l, field } ) , c)
375424 }
376- ExprValue :: AddressOf ( e) => mk ( AddressOf ( sub ( e) ) ) ,
377- ExprValue :: Dereference ( e) => mk ( Dereference ( sub ( e) ) ) ,
378- ExprValue :: Typecast ( e) => mk ( Typecast ( sub ( e) ) ) ,
379- ExprValue :: UnOp { op, e } => mk ( UnOp { op, e : sub ( e) } ) ,
380- ExprValue :: BinOp { op, lhs, rhs } => mk ( BinOp { op, lhs : sub ( lhs) , rhs : sub ( rhs) } ) ,
381- ExprValue :: If { c, t, e } => mk ( If { c : sub ( c) , t : sub ( t) , e : sub ( e) } ) ,
382- ExprValue :: Index { array, index } => mk ( Index { array : sub ( array) , index : sub ( index) } ) ,
383- ExprValue :: Member { lhs, field } => mk ( Member { lhs : sub ( lhs) , field } ) ,
384425 ExprValue :: FunctionCall { function, arguments } => {
385- mk ( FunctionCall { function : sub ( function) , arguments : sub_vec ( arguments) } )
426+ let ( f, c1) = sub ( function) ;
427+ let ( a, c2) = sub_vec ( arguments) ;
428+ ( mk ( FunctionCall { function : f, arguments : a } ) , c1 || c2)
429+ }
430+ ExprValue :: Array { elems } => {
431+ let ( e, c) = sub_vec ( elems) ;
432+ ( mk ( Array { elems : e } ) , c)
433+ }
434+ ExprValue :: Struct { values } => {
435+ let ( v, c) = sub_vec ( values) ;
436+ ( mk ( Struct { values : v } ) , c)
437+ }
438+ ExprValue :: Assign { left, right } => {
439+ let ( l, c1) = sub ( left) ;
440+ let ( r, c2) = sub ( right) ;
441+ ( mk ( Assign { left : l, right : r } ) , c1 || c2)
442+ }
443+ ExprValue :: ReadOk { ptr, size } => {
444+ let ( p, c1) = sub ( ptr) ;
445+ let ( s, c2) = sub ( size) ;
446+ ( mk ( ReadOk { ptr : p, size : s } ) , c1 || c2)
447+ }
448+ ExprValue :: ArrayOf { elem } => {
449+ let ( e, c) = sub ( elem) ;
450+ ( mk ( ArrayOf { elem : e } ) , c)
451+ }
452+ ExprValue :: ByteExtract { e, offset } => {
453+ let ( e, c) = sub ( e) ;
454+ ( mk ( ByteExtract { e, offset } ) , c)
455+ }
456+ ExprValue :: SelfOp { op, e } => {
457+ let ( e, c) = sub ( e) ;
458+ ( mk ( SelfOp { op, e } ) , c)
459+ }
460+ ExprValue :: Union { value, field } => {
461+ let ( v, c) = sub ( value) ;
462+ ( mk ( Union { value : v, field } ) , c)
386463 }
387- ExprValue :: Array { elems } => mk ( Array { elems : sub_vec ( elems) } ) ,
388- ExprValue :: Struct { values } => mk ( Struct { values : sub_vec ( values) } ) ,
389- ExprValue :: Assign { left, right } => mk ( Assign { left : sub ( left) , right : sub ( right) } ) ,
390- ExprValue :: ReadOk { ptr, size } => mk ( ReadOk { ptr : sub ( ptr) , size : sub ( size) } ) ,
391- ExprValue :: ArrayOf { elem } => mk ( ArrayOf { elem : sub ( elem) } ) ,
392- ExprValue :: ByteExtract { e, offset } => mk ( ByteExtract { e : sub ( e) , offset } ) ,
393- ExprValue :: SelfOp { op, e } => mk ( SelfOp { op, e : sub ( e) } ) ,
394- ExprValue :: Union { value, field } => mk ( Union { value : sub ( value) , field } ) ,
395464 ExprValue :: Forall { variable, domain } => {
396- mk ( Forall { variable : sub ( variable) , domain : sub ( domain) } )
465+ let ( v, c1) = sub ( variable) ;
466+ let ( d, c2) = sub ( domain) ;
467+ ( mk ( Forall { variable : v, domain : d } ) , c1 || c2)
397468 }
398469 ExprValue :: Exists { variable, domain } => {
399- mk ( Exists { variable : sub ( variable) , domain : sub ( domain) } )
470+ let ( v, c1) = sub ( variable) ;
471+ let ( d, c2) = sub ( domain) ;
472+ ( mk ( Exists { variable : v, domain : d } ) , c1 || c2)
400473 }
401- ExprValue :: Vector { elems } => mk ( Vector { elems : sub_vec ( elems) } ) ,
402- ExprValue :: ShuffleVector { vector1, vector2, indexes } => mk ( ShuffleVector {
403- vector1 : sub ( vector1) ,
404- vector2 : sub ( vector2) ,
405- indexes : sub_vec ( indexes) ,
406- } ) ,
407- // Leaf nodes and statement expressions — no substitution
474+ ExprValue :: Vector { elems } => {
475+ let ( e, c) = sub_vec ( elems) ;
476+ ( mk ( Vector { elems : e } ) , c)
477+ }
478+ ExprValue :: ShuffleVector { vector1, vector2, indexes } => {
479+ let ( v1, c1) = sub ( vector1) ;
480+ let ( v2, c2) = sub ( vector2) ;
481+ let ( ix, c3) = sub_vec ( indexes) ;
482+ ( mk ( ShuffleVector { vector1 : v1, vector2 : v2, indexes : ix } ) , c1 || c2 || c3)
483+ }
484+ // Leaf nodes — no substitution possible
408485 ExprValue :: Symbol { .. }
409486 | ExprValue :: IntConstant ( _)
410487 | ExprValue :: BoolConstant ( _)
@@ -416,8 +493,10 @@ impl Expr {
416493 | ExprValue :: PointerConstant ( _)
417494 | ExprValue :: StringConstant { .. }
418495 | ExprValue :: Nondet
419- | ExprValue :: EmptyUnion
420- | ExprValue :: StatementExpression { .. } => self ,
496+ | ExprValue :: EmptyUnion => ( self , false ) ,
497+ // StatementExpression: not recursed into — must be flattened via
498+ // inline_as_pure_expr before substitution.
499+ ExprValue :: StatementExpression { .. } => ( self , false ) ,
421500 }
422501 }
423502
@@ -1847,18 +1926,16 @@ mod tests {
18471926 fn substitute_symbol_leaf_match ( ) {
18481927 let old: InternedString = "x" . into ( ) ;
18491928 let replacement = int ( 42 ) ;
1850- let result = sym ( "x" ) . substitute_symbol ( & old, & replacement) ;
1929+ let ( result, _changed ) = sym ( "x" ) . substitute_symbol ( & old, & replacement) ;
18511930 assert ! ( matches!( result. value( ) , ExprValue :: IntConstant ( v) if * v == 42 . into( ) ) ) ;
18521931 }
18531932
18541933 #[ test]
18551934 fn substitute_symbol_leaf_no_match ( ) {
18561935 let old: InternedString = "x" . into ( ) ;
18571936 let replacement = int ( 42 ) ;
1858- let result = sym ( "y" ) . substitute_symbol ( & old, & replacement) ;
1859- assert ! (
1860- matches!( result. value( ) , ExprValue :: Symbol { identifier } if identifier. to_string( ) == "y" )
1861- ) ;
1937+ let ( result, _changed) = sym ( "y" ) . substitute_symbol ( & old, & replacement) ;
1938+ assert ! ( matches!( result. value( ) , ExprValue :: Symbol { identifier } if * identifier == "y" ) ) ;
18621939 }
18631940
18641941 #[ test]
@@ -1867,7 +1944,7 @@ mod tests {
18671944 let replacement = int ( 10 ) ;
18681945 // x + 1 → 10 + 1
18691946 let expr = sym ( "x" ) . plus ( int ( 1 ) ) ;
1870- let result = expr. substitute_symbol ( & old, & replacement) ;
1947+ let ( result, _changed ) = expr. substitute_symbol ( & old, & replacement) ;
18711948 if let ExprValue :: BinOp { lhs, rhs, .. } = result. value ( ) {
18721949 assert ! ( matches!( lhs. value( ) , ExprValue :: IntConstant ( v) if * v == 10 . into( ) ) ) ;
18731950 assert ! ( matches!( rhs. value( ) , ExprValue :: IntConstant ( v) if * v == 1 . into( ) ) ) ;
@@ -1882,7 +1959,7 @@ mod tests {
18821959 let replacement = int ( 5 ) ;
18831960 // (x + x) * 2 → (5 + 5) * 2
18841961 let expr = sym ( "x" ) . plus ( sym ( "x" ) ) . mul ( int ( 2 ) ) ;
1885- let result = expr. substitute_symbol ( & old, & replacement) ;
1962+ let ( result, _changed ) = expr. substitute_symbol ( & old, & replacement) ;
18861963 if let ExprValue :: BinOp { lhs, .. } = result. value ( ) {
18871964 if let ExprValue :: BinOp { lhs : ll, rhs : lr, .. } = lhs. value ( ) {
18881965 assert ! ( matches!( ll. value( ) , ExprValue :: IntConstant ( v) if * v == 5 . into( ) ) ) ;
@@ -1900,7 +1977,7 @@ mod tests {
19001977 let old: InternedString = "x" . into ( ) ;
19011978 let replacement = int ( 7 ) ;
19021979 let expr = sym ( "x" ) . cast_to ( Type :: signed_int ( 64 ) ) ;
1903- let result = expr. substitute_symbol ( & old, & replacement) ;
1980+ let ( result, _changed ) = expr. substitute_symbol ( & old, & replacement) ;
19041981 if let ExprValue :: Typecast ( inner) = result. value ( ) {
19051982 assert ! ( matches!( inner. value( ) , ExprValue :: IntConstant ( v) if * v == 7 . into( ) ) ) ;
19061983 } else {
@@ -1914,11 +1991,9 @@ mod tests {
19141991 let replacement = int ( 1 ) ;
19151992 // y + x → y + 1
19161993 let expr = sym ( "y" ) . plus ( sym ( "x" ) ) ;
1917- let result = expr. substitute_symbol ( & old, & replacement) ;
1994+ let ( result, _changed ) = expr. substitute_symbol ( & old, & replacement) ;
19181995 if let ExprValue :: BinOp { lhs, rhs, .. } = result. value ( ) {
1919- assert ! (
1920- matches!( lhs. value( ) , ExprValue :: Symbol { identifier } if identifier. to_string( ) == "y" )
1921- ) ;
1996+ assert ! ( matches!( lhs. value( ) , ExprValue :: Symbol { identifier } if * identifier == "y" ) ) ;
19221997 assert ! ( matches!( rhs. value( ) , ExprValue :: IntConstant ( v) if * v == 1 . into( ) ) ) ;
19231998 } else {
19241999 panic ! ( "Expected BinOp" ) ;
0 commit comments