Skip to content

Commit c1b27da

Browse files
committed
fix(symbolic-proofs): preserve list invariants and harden stats
1 parent 10f2d9c commit c1b27da

9 files changed

Lines changed: 202 additions & 61 deletions

File tree

kmir/src/kmir/_prove.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
_LOGGER: Final = logging.getLogger(__name__)
3232

33-
3433
def prove(opts: ProveOpts) -> APRProof:
3534
if not opts.rs_file.is_file():
3635
raise ValueError(f'Input file does not exist: {opts.rs_file}')
@@ -128,7 +127,6 @@ def _prove(opts: ProveOpts, target_path: Path, label: str, *, allow_rpc_recovery
128127
break_every_step=opts.break_every_step,
129128
break_on_function=opts.break_on_function,
130129
)
131-
132130
try:
133131
if opts.max_workers and opts.max_workers > 1:
134132
_prove_parallel(kmir, proof, opts=opts, label=label, cut_point_rules=cut_point_rules)

kmir/src/kmir/kast.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,11 @@ def _symbolic_value(self, ty: Ty, mutable: bool) -> tuple[KInner, Iterable[KInne
380380
mlEqualsTrue(leInt(variant_var, token(max_variant))),
381381
]
382382
args = self._fresh_var('ENUM_ARGS')
383-
return KApply('Value::Aggregate', (KApply('variantIdx', (variant_var,)), args)), idx_range, None
383+
return (
384+
KApply('Value::Aggregate', (KApply('variantIdx', (variant_var,)), args)),
385+
idx_range + [mlEqualsTrue(KApply('allValues', (args,)))],
386+
None,
387+
)
384388

385389
case StructT(_, _, fields):
386390
field_vars: list[KInner] = []
@@ -397,14 +401,18 @@ def _symbolic_value(self, ty: Ty, mutable: bool) -> tuple[KInner, Iterable[KInne
397401

398402
case UnionT():
399403
args = self._fresh_var('ARG_UNION')
400-
return KApply('Value::Aggregate', (KApply('variantIdx', (token(0),)), args)), [], None
404+
return (
405+
KApply('Value::Aggregate', (KApply('variantIdx', (token(0),)), args)),
406+
[mlEqualsTrue(KApply('allValues', (args,)))],
407+
None,
408+
)
401409

402410
case ArrayT(_, None):
403411
elems = self._fresh_var('ARG_ARRAY')
404412
l = self._fresh_var('ARG_ARRAY_LEN')
405413
return (
406414
KApply('Value::Range', (elems,)),
407-
[mlEqualsTrue(eqInt(KApply('sizeList', (elems,)), l))],
415+
[mlEqualsTrue(eqInt(KApply('sizeList', (elems,)), l)), mlEqualsTrue(KApply('allValues', (elems,)))],
408416
KApply(
409417
'Metadata',
410418
(

kmir/src/kmir/kdist/mir-semantics/kmir-ast.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,7 @@ module KMIR-AST
3333
3434
syntax TypeMappings ::= List{TypeMapping, ""} [group(mir-list), symbol(TypeMappings::append), terminator-symbol(TypeMappings::empty)]
3535
36+
syntax Bool ::= allValues ( List ) [function, total, symbol(allValues)]
37+
3638
endmodule
3739
```

kmir/src/kmir/kdist/mir-semantics/lemmas/kmir-lemmas.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module KMIR-LEMMAS
1616
imports INT-SYMBOLIC
1717
imports BOOL
1818
19+
imports KMIR-AST
1920
imports RT-DATA
2021
```
2122
## Simplifications for lists to avoid spurious branching on error cases in control flow
@@ -33,6 +34,39 @@ The lists used in the semantics are cons-lists, so only rules with a head elemen
3334
[simplification, symbolic(REST)]
3435
3536
rule 0 <=Int size(_LIST:List) => true [simplification]
37+
38+
// `#reserveSlots` grows `ownedSlots` and `<slotStore>` in lockstep. These simplifications
39+
// let `frameLocal` peel away irrelevant tail updates when reading an older local, and
40+
// directly return the newly-added local when the read reaches the matching tail slot.
41+
rule frameLocal(_STORE[SLOT <- LOCAL], SLOTS ListItem(SLOT), size(SLOTS)) => LOCAL
42+
requires isTypedLocal(LOCAL)
43+
[simplification]
44+
45+
rule frameLocal(STORE[SLOT <- _], SLOTS ListItem(SLOT), IDX) => frameLocal(STORE, SLOTS, IDX)
46+
requires 0 <=Int IDX andBool IDX <Int size(SLOTS)
47+
[simplification]
48+
49+
// --------------------------------------------------
50+
rule allValues(.List) => true
51+
rule allValues(ListItem(_:Value) REST) => allValues(REST)
52+
rule allValues(ListItem(_) _REST) => false [owise]
53+
54+
// Symbolic prove-rs inputs use fresh `List` variables to stand for arrays, slices,
55+
// and aggregate argument lists whose elements are still runtime `Value`s. Carrying
56+
// that invariant explicitly lets reads and writes avoid spurious branches on the
57+
// underlying builtin `List:get` / `List:set` definedness checks.
58+
rule isValue(ELEMS[IDX])
59+
=> true
60+
requires allValues(ELEMS)
61+
andBool 0 <=Int IDX
62+
andBool IDX <Int size(ELEMS)
63+
[simplification, symbolic(ELEMS)]
64+
65+
rule #Ceil(ELEMS[IDX <- _VAL:Value])
66+
=> #Ceil(ELEMS)
67+
#And {true #Equals allValues(ELEMS)}
68+
#And {true #Equals 0 <=Int IDX andBool IDX <Int size(ELEMS)}
69+
[simplification, symbolic(ELEMS)]
3670
```
3771

3872
The hooked `range` function selects a segment from a list, by removing elements from front and back.

kmir/src/kmir/kdist/mir-semantics/rt/data.md

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,6 @@ More often than not, a slot or list element must be selected by index and is req
6060
requires 0 <=Int IDX andBool IDX <Int size(SLOTS)
6161
[preserves-definedness]
6262
63-
// Fresh callee slots are appended to the frame list and written into the store in lockstep.
64-
// These simplifications let current-frame reads reduce through the most recent updates even
65-
// when the slot ids themselves are symbolic fresh values.
66-
rule frameLocal(_STORE[SLOT <- LOCAL], SLOTS ListItem(SLOT), size(SLOTS)) => LOCAL
67-
requires isTypedLocal(LOCAL)
68-
[simplification]
69-
70-
rule frameLocal(STORE[SLOT <- _], SLOTS ListItem(SLOT), IDX) => frameLocal(STORE, SLOTS, IDX)
71-
requires 0 <=Int IDX andBool IDX <Int size(SLOTS)
72-
[simplification]
73-
7463
// indexing values out of TypedValue, runtime slots, and Value lists
7564
syntax Value ::= getSlotValue ( Map, Int ) [function]
7665
| frameValue ( Map, List, Int ) [function]
@@ -241,13 +230,23 @@ If we are setting a value at a `Place` which has `Projection`s in it, then we mu
241230
andBool isNewLocal(getSlot(STORE, SLOT))
242231
[preserves-definedness] // valid lookup checked
243232
244-
rule <k> #setLocalValue(place(local(I), .ProjectionElems), VAL:Value)
245-
=> #setSlotValue(#frameSlotId(SLOTS, I), VAL)
246-
...
247-
</k>
233+
rule <k> #setLocalValue(place(local(I), .ProjectionElems), VAL:Value) => .K ... </k>
234+
<currentFrame> <ownedSlots> SLOTS </ownedSlots> ... </currentFrame>
235+
<slotStore>
236+
STORE => STORE[#frameSlotId(SLOTS, I) <- typedValue(VAL, tyOfLocal(frameLocal(STORE, SLOTS, I)), mutabilityOf(frameLocal(STORE, SLOTS, I)))]
237+
</slotStore>
238+
requires 0 <=Int I andBool I <Int size(SLOTS)
239+
andBool isTypedValue(frameLocal(STORE, SLOTS, I))
240+
[preserves-definedness] // valid slot indexing and lookup checked
241+
242+
rule <k> #setLocalValue(place(local(I), .ProjectionElems), VAL:Value) => .K ... </k>
248243
<currentFrame> <ownedSlots> SLOTS </ownedSlots> ... </currentFrame>
244+
<slotStore>
245+
STORE => STORE[#frameSlotId(SLOTS, I) <- typedValue(VAL, tyOfLocal(frameLocal(STORE, SLOTS, I)), mutabilityOf(frameLocal(STORE, SLOTS, I)))]
246+
</slotStore>
249247
requires 0 <=Int I andBool I <Int size(SLOTS)
250-
[preserves-definedness] // valid slot indexing checked
248+
andBool isNewLocal(frameLocal(STORE, SLOTS, I))
249+
[preserves-definedness] // valid slot indexing and lookup checked
251250
252251
rule <k> #setLocalValue(place(local(I), PROJ), VAL:Value)
253252
=> #traverseProjection(toSlot(#frameSlotId(SLOTS, I)), frameValue(STORE, SLOTS, I), PROJ, .Contexts)

kmir/src/kmir/utils.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import heapq
34
import re
45
from pathlib import Path
56
from typing import TYPE_CHECKING, Sequence
@@ -177,53 +178,62 @@ def classify(node_id: int) -> str:
177178
reachable_leaf_count = 0
178179
leaf_lines: list[str] = []
179180

180-
def _path_nodes(source_id: int, path: Sequence[KCFG.Successor]) -> list[int]:
181+
def _successor_edges(source_id: int) -> list[tuple[int, int]]:
181182
from pyk.kcfg.kcfg import KCFG as _KCFG
182183

183-
node_ids = [source_id]
184-
current = source_id
185-
for succ in path:
186-
target_id: int | None = None
187-
if isinstance(succ, _KCFG.EdgeLike):
188-
target_id = succ.target.id
189-
elif isinstance(succ, _KCFG.MultiEdge):
190-
targets = list(succ.targets)
191-
if len(targets) == 1:
192-
target_id = targets[0].id
193-
if target_id is not None and target_id != current:
194-
node_ids.append(target_id)
195-
current = target_id
196-
return node_ids
197-
198-
for leaf in sorted(leaves, key=lambda n: n.id):
199-
paths = kcfg.paths_between(proof.init, leaf.id)
200-
if not paths:
201-
leaf_lines.append(f' leaf {leaf.id}: unreachable from init')
202-
continue
184+
edges: list[tuple[int, int]] = []
185+
for succ in kcfg.successors(source_id):
186+
match succ:
187+
case _KCFG.Edge(target=target, depth=depth):
188+
edges.append((target.id, depth))
189+
case _KCFG.MergedEdge(target=target, edges=merged_edges):
190+
edges.append((target.id, min(edge.depth for edge in merged_edges)))
191+
case _KCFG.Cover(target=target):
192+
edges.append((target.id, 0))
193+
case _KCFG.Split(targets=targets):
194+
edges.extend((target.id, 0) for target in targets)
195+
case _KCFG.NDBranch(targets=targets):
196+
edges.extend((target.id, 1) for target in targets)
197+
case _:
198+
raise ValueError(f'Cannot handle Successor type: {type(succ)}')
199+
return edges
203200

204-
path_infos: list[tuple[int, tuple[int, ...]]] = []
205-
seen_sequences: set[tuple[int, ...]] = set()
201+
shortest_steps: dict[int, int] = {proof.init: 0}
202+
shortest_prev: dict[int, int] = {}
203+
worklist: list[tuple[int, int]] = [(0, proof.init)]
206204

207-
for path in paths:
208-
steps = kcfg.path_length(path)
209-
node_seq = tuple(_path_nodes(proof.init, path))
210-
if node_seq in seen_sequences:
211-
continue
212-
seen_sequences.add(node_seq)
213-
path_infos.append((steps, node_seq))
205+
while worklist:
206+
curr_steps, node_id = heapq.heappop(worklist)
207+
if curr_steps != shortest_steps.get(node_id):
208+
continue
209+
for target_id, weight in sorted(_successor_edges(node_id)):
210+
next_steps = curr_steps + weight
211+
prev_steps = shortest_steps.get(target_id)
212+
# Keep the first equal-cost predecessor. Rewriting predecessors on
213+
# ties can create zero-cost cycles through Cover/Split edges and
214+
# make path reconstruction loop forever.
215+
if prev_steps is None or next_steps < prev_steps:
216+
shortest_steps[target_id] = next_steps
217+
shortest_prev[target_id] = node_id
218+
heapq.heappush(worklist, (next_steps, target_id))
219+
220+
def _shortest_path_nodes(target_id: int) -> list[int]:
221+
node_ids = [target_id]
222+
while node_ids[-1] != proof.init:
223+
node_ids.append(shortest_prev[node_ids[-1]])
224+
node_ids.reverse()
225+
return node_ids
214226

215-
if not path_infos:
227+
for leaf in sorted(leaves, key=lambda n: n.id):
228+
min_steps = shortest_steps.get(leaf.id)
229+
if min_steps is None:
216230
leaf_lines.append(f' leaf {leaf.id}: unreachable from init')
217231
continue
218232

219-
total_steps += min(steps for steps, _ in path_infos)
233+
total_steps += min_steps
220234
reachable_leaf_count += 1
221-
path_infos.sort(key=lambda info: (info[0], info[1]))
222-
223-
for idx, (steps, node_seq) in enumerate(path_infos, start=1):
224-
suffix = '' if len(path_infos) == 1 else f' (path {idx}/{len(path_infos)})'
225-
seq_str = ' -> '.join(str(nid) for nid in node_seq)
226-
leaf_lines.append(f' leaf {leaf.id}{suffix}: steps {steps}, path {seq_str}')
235+
seq_str = ' -> '.join(str(nid) for nid in _shortest_path_nodes(leaf.id))
236+
leaf_lines.append(f' leaf {leaf.id}: shortest steps {min_steps}, path {seq_str}')
227237

228238
lines.append(f' total leaves (non-root): {len(leaves)}')
229239
lines.append(f' reachable leaves : {reachable_leaf_count}')

kmir/src/tests/integration/test_cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# they don't differ between local checkouts and CI (e.g. symbolic-args-fail.main.cli-stats-leaves).
2323
_REPO_ROOT = str(Path(__file__).resolve().parents[4])
2424
_PATH_REPLACEMENTS: dict[str, str] = {_REPO_ROOT + '/': '<REPO>/'}
25+
_SNAPSHOT_PROVE_MAX_DEPTH = 50
2526

2627

2728
def _prove_and_store(
@@ -32,7 +33,8 @@ def _prove_and_store(
3233
is_smir: bool = False,
3334
max_depth: int | None = None,
3435
) -> APRProof:
35-
opts = ProveOpts(rs_or_json, proof_dir=tmp_path, smir=is_smir, start_symbol=start_symbol, max_depth=max_depth)
36+
proof_max_depth = _SNAPSHOT_PROVE_MAX_DEPTH if max_depth is None else max_depth
37+
opts = ProveOpts(rs_or_json, proof_dir=tmp_path, smir=is_smir, start_symbol=start_symbol, max_depth=proof_max_depth)
3638
apr_proof = kmir.prove_program(opts)
3739
apr_proof.write_proof_data()
3840
return apr_proof

kmir/src/tests/integration/test_integration.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@
7373
'ptr-cast-array-to-singleton-wrapped-array-fail',
7474
]
7575

76+
PROVE_EXPECTED_FAILURES = {
77+
('symbolic-args-fail', 'eats_all_args'): False,
78+
}
79+
SNAPSHOT_PROVE_MAX_DEPTH = 50
80+
7681

7782
@pytest.mark.parametrize(
7883
'rs_file',
@@ -87,7 +92,7 @@ def test_prove(rs_file: Path, kmir: KMIR, update_expected_output: bool) -> None:
8792
if update_expected_output and not should_show:
8893
pytest.skip()
8994

90-
prove_opts = ProveOpts(rs_file, smir=is_smir, terminate_on_thunk=True)
95+
prove_opts = ProveOpts(rs_file, smir=is_smir, terminate_on_thunk=True, max_depth=SNAPSHOT_PROVE_MAX_DEPTH)
9196
printer = PrettyPrinter(kmir.definition)
9297
cterm_show = CTermShow(printer.print)
9398

@@ -98,6 +103,7 @@ def test_prove(rs_file: Path, kmir: KMIR, update_expected_output: bool) -> None:
98103
for start_symbol in start_symbols:
99104
prove_opts.start_symbol = start_symbol
100105
apr_proof = kmir.prove_program(prove_opts)
106+
should_fail = PROVE_EXPECTED_FAILURES.get((rs_file.stem, start_symbol), rs_file.stem.endswith('fail'))
101107

102108
if should_show:
103109
display_opts = ShowOpts(

kmir/src/tests/unit/test_utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
5+
from pyk.cterm import CSubst, CTerm
6+
from pyk.kast.inner import KApply
7+
from pyk.kcfg.kcfg import KCFG
8+
9+
from kmir.utils import render_statistics
10+
11+
12+
@dataclass
13+
class _FakeKCFG:
14+
nodes: tuple[KCFG.Node, ...]
15+
leaves: tuple[KCFG.Node, ...]
16+
root_ids: frozenset[int]
17+
successor_map: dict[int, tuple[object, ...]]
18+
19+
def is_root(self, node_id: int) -> bool:
20+
return node_id in self.root_ids
21+
22+
def successors(self, node_id: int) -> tuple[object, ...]:
23+
return self.successor_map.get(node_id, ())
24+
25+
def is_split(self, _node_id: int) -> bool:
26+
return False
27+
28+
def is_ndbranch(self, _node_id: int) -> bool:
29+
return False
30+
31+
def is_stuck(self, _node_id: int) -> bool:
32+
return False
33+
34+
35+
@dataclass
36+
class _FakeProof:
37+
kcfg: _FakeKCFG
38+
init: int
39+
pending_ids: frozenset[int]
40+
41+
def is_target(self, _node_id: int) -> bool:
42+
return False
43+
44+
def is_terminal(self, _node_id: int) -> bool:
45+
return False
46+
47+
def is_refuted(self, _node_id: int) -> bool:
48+
return False
49+
50+
def is_bounded(self, _node_id: int) -> bool:
51+
return False
52+
53+
def is_pending(self, node_id: int) -> bool:
54+
return node_id in self.pending_ids
55+
56+
def is_failing(self, _node_id: int) -> bool:
57+
return False
58+
59+
60+
def test_render_statistics_handles_zero_cost_predecessor_cycles() -> None:
61+
kcfg = KCFG()
62+
loop_target = kcfg.create_node(CTerm(KApply('<loopTarget>')))
63+
init = kcfg.create_node(CTerm(KApply('<init>')))
64+
leaf = kcfg.create_node(CTerm(KApply('<leaf>')))
65+
66+
fake_kcfg = _FakeKCFG(
67+
nodes=(loop_target, init, leaf),
68+
leaves=(leaf,),
69+
root_ids=frozenset({init.id}),
70+
successor_map={
71+
init.id: (KCFG.Cover(init, loop_target, CSubst()),),
72+
loop_target.id: (
73+
KCFG.Cover(loop_target, init, CSubst()),
74+
KCFG.Edge(loop_target, leaf, 1, ()),
75+
),
76+
},
77+
)
78+
proof = _FakeProof(fake_kcfg, init=init.id, pending_ids=frozenset({leaf.id}))
79+
80+
lines = render_statistics(proof)
81+
82+
assert f' leaf {leaf.id}: shortest steps 1, path {init.id} -> {loop_target.id} -> {leaf.id}' in lines

0 commit comments

Comments
 (0)