Skip to content

Commit 8b2600a

Browse files
committed
Improve robustness of pure expression inliner
- substitute_symbol returns (Expr, bool) for reliable change detection instead of Debug-format string comparison
- Graceful fallback for recursive functions (tracing::warn + return original) instead of assert! panic
- Diagnostics for edge cases: multiple assignments to same variable, non-symbol return expressions, unknown UnaryOperator variants
- Public API cleanup: inline_as_pure_expr_toplevel wrapper hides the visited set implementation detail
- Document StatementExpression non-recursion in substitute_symbol

Signed-off-by: Felipe R. Monteiro <felisous@amazon.com>
1 parent 45648b6 commit 8b2600a

File tree

3 files changed

+197
-107
lines changed

3 files changed

+197
-107
lines changed

cprover_bindings/src/goto_program/expr.rs

Lines changed: 120 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,11 @@ impl Expr {
354354
/// Predicates
355355
impl Expr {
356356
/// Replace all occurrences of `Symbol { identifier: old_id }` with `replacement`.
357-
/// Produces a new expression tree with all substitutions applied.
358-
/// Used for quantifier codegen to inline function bodies as pure expressions.
359-
pub fn substitute_symbol(self, old_id: &InternedString, replacement: &Expr) -> Expr {
357+
/// Returns `(new_expr, changed)` where `changed` indicates if any substitution occurred.
358+
///
359+
/// Note: Does NOT recurse into `StatementExpression` nodes. These must be
360+
/// flattened first via `inline_as_pure_expr` before substitution.
361+
pub fn substitute_symbol(self, old_id: &InternedString, replacement: &Expr) -> (Expr, bool) {
360362
let loc = self.location;
361363
let typ = self.typ.clone();
362364
let ann = self.size_of_annotation.clone();
@@ -367,44 +369,119 @@ impl Expr {
367369
size_of_annotation: ann.clone(),
368370
};
369371
let sub = |e: Expr| e.substitute_symbol(old_id, replacement);
370-
let sub_vec = |v: Vec<Expr>| v.into_iter().map(&sub).collect();
372+
let sub_vec = |v: Vec<Expr>| -> (Vec<Expr>, bool) {
373+
let mut changed = false;
374+
let v: Vec<_> = v
375+
.into_iter()
376+
.map(|e| {
377+
let (e, c) = sub(e);
378+
changed |= c;
379+
e
380+
})
381+
.collect();
382+
(v, changed)
383+
};
371384

372385
match *self.value {
373386
ExprValue::Symbol { identifier } if identifier == *old_id => {
374-
replacement.clone().with_location(loc)
387+
(replacement.clone().with_location(loc), true)
388+
}
389+
ExprValue::AddressOf(e) => {
390+
let (e, c) = sub(e);
391+
(mk(AddressOf(e)), c)
392+
}
393+
ExprValue::Dereference(e) => {
394+
let (e, c) = sub(e);
395+
(mk(Dereference(e)), c)
396+
}
397+
ExprValue::Typecast(e) => {
398+
let (e, c) = sub(e);
399+
(mk(Typecast(e)), c)
400+
}
401+
ExprValue::UnOp { op, e } => {
402+
let (e, c) = sub(e);
403+
(mk(UnOp { op, e }), c)
404+
}
405+
ExprValue::BinOp { op, lhs, rhs } => {
406+
let (l, c1) = sub(lhs);
407+
let (r, c2) = sub(rhs);
408+
(mk(BinOp { op, lhs: l, rhs: r }), c1 || c2)
409+
}
410+
ExprValue::If { c, t, e } => {
411+
let (c, c1) = sub(c);
412+
let (t, c2) = sub(t);
413+
let (e, c3) = sub(e);
414+
(mk(If { c, t, e }), c1 || c2 || c3)
415+
}
416+
ExprValue::Index { array, index } => {
417+
let (a, c1) = sub(array);
418+
let (i, c2) = sub(index);
419+
(mk(Index { array: a, index: i }), c1 || c2)
420+
}
421+
ExprValue::Member { lhs, field } => {
422+
let (l, c) = sub(lhs);
423+
(mk(Member { lhs: l, field }), c)
375424
}
376-
ExprValue::AddressOf(e) => mk(AddressOf(sub(e))),
377-
ExprValue::Dereference(e) => mk(Dereference(sub(e))),
378-
ExprValue::Typecast(e) => mk(Typecast(sub(e))),
379-
ExprValue::UnOp { op, e } => mk(UnOp { op, e: sub(e) }),
380-
ExprValue::BinOp { op, lhs, rhs } => mk(BinOp { op, lhs: sub(lhs), rhs: sub(rhs) }),
381-
ExprValue::If { c, t, e } => mk(If { c: sub(c), t: sub(t), e: sub(e) }),
382-
ExprValue::Index { array, index } => mk(Index { array: sub(array), index: sub(index) }),
383-
ExprValue::Member { lhs, field } => mk(Member { lhs: sub(lhs), field }),
384425
ExprValue::FunctionCall { function, arguments } => {
385-
mk(FunctionCall { function: sub(function), arguments: sub_vec(arguments) })
426+
let (f, c1) = sub(function);
427+
let (a, c2) = sub_vec(arguments);
428+
(mk(FunctionCall { function: f, arguments: a }), c1 || c2)
429+
}
430+
ExprValue::Array { elems } => {
431+
let (e, c) = sub_vec(elems);
432+
(mk(Array { elems: e }), c)
433+
}
434+
ExprValue::Struct { values } => {
435+
let (v, c) = sub_vec(values);
436+
(mk(Struct { values: v }), c)
437+
}
438+
ExprValue::Assign { left, right } => {
439+
let (l, c1) = sub(left);
440+
let (r, c2) = sub(right);
441+
(mk(Assign { left: l, right: r }), c1 || c2)
442+
}
443+
ExprValue::ReadOk { ptr, size } => {
444+
let (p, c1) = sub(ptr);
445+
let (s, c2) = sub(size);
446+
(mk(ReadOk { ptr: p, size: s }), c1 || c2)
447+
}
448+
ExprValue::ArrayOf { elem } => {
449+
let (e, c) = sub(elem);
450+
(mk(ArrayOf { elem: e }), c)
451+
}
452+
ExprValue::ByteExtract { e, offset } => {
453+
let (e, c) = sub(e);
454+
(mk(ByteExtract { e, offset }), c)
455+
}
456+
ExprValue::SelfOp { op, e } => {
457+
let (e, c) = sub(e);
458+
(mk(SelfOp { op, e }), c)
459+
}
460+
ExprValue::Union { value, field } => {
461+
let (v, c) = sub(value);
462+
(mk(Union { value: v, field }), c)
386463
}
387-
ExprValue::Array { elems } => mk(Array { elems: sub_vec(elems) }),
388-
ExprValue::Struct { values } => mk(Struct { values: sub_vec(values) }),
389-
ExprValue::Assign { left, right } => mk(Assign { left: sub(left), right: sub(right) }),
390-
ExprValue::ReadOk { ptr, size } => mk(ReadOk { ptr: sub(ptr), size: sub(size) }),
391-
ExprValue::ArrayOf { elem } => mk(ArrayOf { elem: sub(elem) }),
392-
ExprValue::ByteExtract { e, offset } => mk(ByteExtract { e: sub(e), offset }),
393-
ExprValue::SelfOp { op, e } => mk(SelfOp { op, e: sub(e) }),
394-
ExprValue::Union { value, field } => mk(Union { value: sub(value), field }),
395464
ExprValue::Forall { variable, domain } => {
396-
mk(Forall { variable: sub(variable), domain: sub(domain) })
465+
let (v, c1) = sub(variable);
466+
let (d, c2) = sub(domain);
467+
(mk(Forall { variable: v, domain: d }), c1 || c2)
397468
}
398469
ExprValue::Exists { variable, domain } => {
399-
mk(Exists { variable: sub(variable), domain: sub(domain) })
470+
let (v, c1) = sub(variable);
471+
let (d, c2) = sub(domain);
472+
(mk(Exists { variable: v, domain: d }), c1 || c2)
400473
}
401-
ExprValue::Vector { elems } => mk(Vector { elems: sub_vec(elems) }),
402-
ExprValue::ShuffleVector { vector1, vector2, indexes } => mk(ShuffleVector {
403-
vector1: sub(vector1),
404-
vector2: sub(vector2),
405-
indexes: sub_vec(indexes),
406-
}),
407-
// Leaf nodes and statement expressions — no substitution
474+
ExprValue::Vector { elems } => {
475+
let (e, c) = sub_vec(elems);
476+
(mk(Vector { elems: e }), c)
477+
}
478+
ExprValue::ShuffleVector { vector1, vector2, indexes } => {
479+
let (v1, c1) = sub(vector1);
480+
let (v2, c2) = sub(vector2);
481+
let (ix, c3) = sub_vec(indexes);
482+
(mk(ShuffleVector { vector1: v1, vector2: v2, indexes: ix }), c1 || c2 || c3)
483+
}
484+
// Leaf nodes — no substitution possible
408485
ExprValue::Symbol { .. }
409486
| ExprValue::IntConstant(_)
410487
| ExprValue::BoolConstant(_)
@@ -416,8 +493,10 @@ impl Expr {
416493
| ExprValue::PointerConstant(_)
417494
| ExprValue::StringConstant { .. }
418495
| ExprValue::Nondet
419-
| ExprValue::EmptyUnion
420-
| ExprValue::StatementExpression { .. } => self,
496+
| ExprValue::EmptyUnion => (self, false),
497+
// StatementExpression: not recursed into — must be flattened via
498+
// inline_as_pure_expr before substitution.
499+
ExprValue::StatementExpression { .. } => (self, false),
421500
}
422501
}
423502

@@ -1847,18 +1926,16 @@ mod tests {
18471926
fn substitute_symbol_leaf_match() {
18481927
let old: InternedString = "x".into();
18491928
let replacement = int(42);
1850-
let result = sym("x").substitute_symbol(&old, &replacement);
1929+
let (result, _changed) = sym("x").substitute_symbol(&old, &replacement);
18511930
assert!(matches!(result.value(), ExprValue::IntConstant(v) if *v == 42.into()));
18521931
}
18531932

18541933
#[test]
18551934
fn substitute_symbol_leaf_no_match() {
18561935
let old: InternedString = "x".into();
18571936
let replacement = int(42);
1858-
let result = sym("y").substitute_symbol(&old, &replacement);
1859-
assert!(
1860-
matches!(result.value(), ExprValue::Symbol { identifier } if identifier.to_string() == "y")
1861-
);
1937+
let (result, _changed) = sym("y").substitute_symbol(&old, &replacement);
1938+
assert!(matches!(result.value(), ExprValue::Symbol { identifier } if *identifier == "y"));
18621939
}
18631940

18641941
#[test]
@@ -1867,7 +1944,7 @@ mod tests {
18671944
let replacement = int(10);
18681945
// x + 1 → 10 + 1
18691946
let expr = sym("x").plus(int(1));
1870-
let result = expr.substitute_symbol(&old, &replacement);
1947+
let (result, _changed) = expr.substitute_symbol(&old, &replacement);
18711948
if let ExprValue::BinOp { lhs, rhs, .. } = result.value() {
18721949
assert!(matches!(lhs.value(), ExprValue::IntConstant(v) if *v == 10.into()));
18731950
assert!(matches!(rhs.value(), ExprValue::IntConstant(v) if *v == 1.into()));
@@ -1882,7 +1959,7 @@ mod tests {
18821959
let replacement = int(5);
18831960
// (x + x) * 2 → (5 + 5) * 2
18841961
let expr = sym("x").plus(sym("x")).mul(int(2));
1885-
let result = expr.substitute_symbol(&old, &replacement);
1962+
let (result, _changed) = expr.substitute_symbol(&old, &replacement);
18861963
if let ExprValue::BinOp { lhs, .. } = result.value() {
18871964
if let ExprValue::BinOp { lhs: ll, rhs: lr, .. } = lhs.value() {
18881965
assert!(matches!(ll.value(), ExprValue::IntConstant(v) if *v == 5.into()));
@@ -1900,7 +1977,7 @@ mod tests {
19001977
let old: InternedString = "x".into();
19011978
let replacement = int(7);
19021979
let expr = sym("x").cast_to(Type::signed_int(64));
1903-
let result = expr.substitute_symbol(&old, &replacement);
1980+
let (result, _changed) = expr.substitute_symbol(&old, &replacement);
19041981
if let ExprValue::Typecast(inner) = result.value() {
19051982
assert!(matches!(inner.value(), ExprValue::IntConstant(v) if *v == 7.into()));
19061983
} else {
@@ -1914,11 +1991,9 @@ mod tests {
19141991
let replacement = int(1);
19151992
// y + x → y + 1
19161993
let expr = sym("y").plus(sym("x"));
1917-
let result = expr.substitute_symbol(&old, &replacement);
1994+
let (result, _changed) = expr.substitute_symbol(&old, &replacement);
19181995
if let ExprValue::BinOp { lhs, rhs, .. } = result.value() {
1919-
assert!(
1920-
matches!(lhs.value(), ExprValue::Symbol { identifier } if identifier.to_string() == "y")
1921-
);
1996+
assert!(matches!(lhs.value(), ExprValue::Symbol { identifier } if *identifier == "y"));
19221997
assert!(matches!(rhs.value(), ExprValue::IntConstant(v) if *v == 1.into()));
19231998
} else {
19241999
panic!("Expected BinOp");

docs/dev/pure-expression-inliner.md

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,72 +2,49 @@
22

33
## Overview
44

5-
The pure expression inliner (`inline_as_pure_expr`) is a function call inlining
6-
mechanism that produces side-effect-free expression trees. Unlike the original
5+
The pure expression inliner (`inline_as_pure_expr`) inlines function calls
6+
within expression trees as side-effect-free expressions. Unlike the existing
77
`inline_function_calls_in_expr` which wraps inlined bodies in CBMC
8-
`StatementExpression` nodes, this produces expressions using only pure
9-
constructs: `BinOp`, `UnOp`, `If` (ternary), `Typecast`, etc.
10-
11-
## Motivation
12-
13-
CBMC's quantifier expressions (`forall`, `exists`) reject side effects in their
14-
bodies. The original inliner produced `StatementExpression` nodes which CBMC
15-
treats as side effects, causing invariant violations. The pure inliner eliminates
16-
this by producing expression trees that CBMC can process directly.
17-
18-
## How It Works
19-
20-
For a function call `f(arg1, arg2)` where `f` is defined as:
21-
```c
22-
ret_type f(param1, param2) {
23-
local1 = expr1(param1);
24-
local2 = expr2(local1, param2);
25-
return local2;
26-
}
27-
```
28-
29-
The pure inliner:
30-
1. Collects all assignments: `{local1 → expr1(param1), local2 → expr2(local1, param2)}`
31-
2. Finds the return symbol: `local2`
32-
3. Resolves intermediates: `local2` → `expr2(local1, param2)` → `expr2(expr1(param1), param2)`
33-
4. Substitutes parameters: `expr2(expr1(arg1), arg2)`
34-
5. Flattens `StatementExpression` nodes (e.g., checked arithmetic → just the operation)
35-
6. Recursively inlines any remaining function calls
8+
`StatementExpression` nodes, this produces pure expression trees.
369

3710
## Soundness Implications
3811

3912
**Checked arithmetic in quantifier bodies**: When flattening `StatementExpression`
4013
nodes (e.g., from checked division or remainder), the pure inliner drops the
4114
`Assert` and `Assume` statements that check for overflow and division by zero.
42-
This means:
4315

44-
- **Division by zero** inside a quantifier body will NOT be detected. For example,
45-
`forall!(|i in (0, 10)| arr[i] / x == 0)` where `x` could be zero will not
46-
produce a division-by-zero check.
16+
- **Division by zero** inside a quantifier body will NOT be detected.
4717
- **Arithmetic overflow** inside a quantifier body will NOT be detected.
4818

49-
This is a known trade-off: CBMC requires pure expressions in quantifier bodies,
50-
and runtime checks are inherently side effects. Users should ensure that
51-
arithmetic operations in quantifier predicates cannot overflow or divide by zero.
52-
5319
**Future improvement**: The dropped assertions could be hoisted outside the
54-
quantifier as preconditions, preserving soundness while keeping the quantifier
55-
body pure.
20+
quantifier as preconditions, preserving soundness while keeping the body pure.
5621

5722
## Limitations
5823

5924
- **No control flow**: Functions with `if`/`else` or `match` that produce
6025
multiple assignments to the return variable are not fully supported. The
61-
inliner takes the last assignment, which may not be correct for all paths.
26+
inliner takes the last assignment and emits a `tracing::debug!` diagnostic.
6227
- **No loops**: Functions containing loops cannot be inlined as pure expressions.
63-
- **No recursion**: Recursive functions are detected and cause a panic.
64-
- **Checked arithmetic**: Overflow/division-by-zero checks (`Assert` + `Assume`
65-
statements) are dropped when flattening `StatementExpression` nodes. This
66-
means the pure expression doesn't include these runtime checks.
28+
- **No recursion**: Recursive functions are detected and the original expression
29+
is returned unchanged (with a `tracing::warn!` diagnostic) instead of triggering an internal compiler error (ICE).
30+
- **StatementExpression in substitute_symbol**: `Expr::substitute_symbol` does
31+
NOT recurse into `StatementExpression` nodes. These must be flattened via
32+
`inline_as_pure_expr` before substitution.
33+
34+
## API
35+
36+
```rust
37+
// Public entry point — manages the visited set internally
38+
pub fn inline_as_pure_expr_toplevel(&self, expr: &Expr) -> Expr;
39+
40+
// Expr method — returns (new_expr, changed) for reliable change detection
41+
pub fn substitute_symbol(self, old_id: &InternedString, replacement: &Expr) -> (Expr, bool);
42+
```
6743

6844
## Files
6945

7046
- `cprover_bindings/src/goto_program/expr.rs` — `Expr::substitute_symbol()`
7147
- `kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs` — `inline_as_pure_expr()`,
72-
`inline_call_as_pure_expr()`, `collect_assignments_from_stmt()`,
73-
`find_return_symbol_in_stmt()`, `resolve_intermediates_iterative()`
48+
`inline_as_pure_expr_toplevel()`, `inline_call_as_pure_expr()`,
49+
`collect_assignments_from_stmt()`, `find_return_symbol_in_stmt()`,
50+
`resolve_intermediates_iterative()`

0 commit comments

Comments
 (0)