Skip to content

Commit 2b6ce41

Browse files
authored
kmir prove takes multiple start-symbols (#1071)
This PR adds the ability to provide multiple start symbols to `kmir prove`, and importantly, a single call to `llvm-kompile` for the proofs. In particular: - Names can be provided to `--start-symbol(s)` either comma separated or repeated `kmir prove --start-symbols name1, name2 --start-symbol name3`; - Function `_prove_multi` will reduce the items to the provided names and do a single `llvm-kompile` to be shared; - Each proof is run _sequentially_ and the definition and smir.json is stored in a directory `filename_stem.kompiled`, and the proof dirs just store the kcfg (this is only if multiple start symbols are provided); - `prove_with_kmir` is not removed. It was considered but the way it interacts with the test runner is nicer so I left it; - Test is added to `test_cli`; - Observed to correctly reduce the calls to `llvm-kompile` and has greatly increased performance for mutliple start symbol proofs;
1 parent 2d70882 commit 2b6ce41

6 files changed

Lines changed: 169 additions & 74 deletions

File tree

kmir/src/kmir/__main__.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@
4343
_LOG_FORMAT: Final = '%(levelname)s %(asctime)s %(name)s - %(message)s'
4444

4545

46+
def _flatten_comma_list(values: list[str] | None) -> list[str] | None:
47+
"""Flatten a list that may contain comma-separated entries, e.g. ['a,b', 'c'] -> ['a', 'b', 'c']."""
48+
if values is None:
49+
return None
50+
return [item for v in values for item in v.split(',') if item.strip()]
51+
52+
4653
def _flush_lines(lines: Iterable[str]) -> None:
4754
"""Print lines to stdout one at a time, then release the list if possible."""
4855
for line in lines:
@@ -82,9 +89,13 @@ def run(target_dir: Path):
8289

8390

8491
def _kmir_prove(opts: ProveOpts) -> None:
85-
proof = KMIR.prove_program(opts)
86-
print(str(proof.summary))
87-
if not proof.passed:
92+
proofs = KMIR.prove_programs(opts)
93+
any_failed = False
94+
for proof in proofs:
95+
print(str(proof.summary))
96+
if not proof.passed:
97+
any_failed = True
98+
if any_failed:
8899
sys.exit(1)
89100

90101

@@ -587,7 +598,14 @@ def _arg_parser() -> ArgumentParser:
587598
)
588599
prove_parser.add_argument('--smir', action='store_true', help='Treat the input file as a smir json.')
589600
prove_parser.add_argument(
590-
'--start-symbol', type=str, metavar='SYMBOL', default='main', help='Symbol name to begin execution from'
601+
'--start-symbol',
602+
'--start-symbols',
603+
dest='start_symbols',
604+
type=str,
605+
metavar='SYMBOL',
606+
action='append',
607+
default=None,
608+
help='Symbol name(s) to prove (repeatable, comma-separated allowed)',
591609
)
592610
prove_parser.add_argument(
593611
'--add-module',
@@ -710,7 +728,7 @@ def _parse_args(ns: Namespace) -> KMirOpts:
710728
maintenance_rate=ns.maintenance_rate,
711729
save_smir=ns.save_smir,
712730
smir=ns.smir,
713-
start_symbol=ns.start_symbol,
731+
start_symbols=_flatten_comma_list(ns.start_symbols),
714732
break_on_calls=ns.break_on_calls,
715733
break_on_function_calls=ns.break_on_function_calls,
716734
break_on_intrinsic_calls=ns.break_on_intrinsic_calls,

kmir/src/kmir/_prove.py

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,78 +31,96 @@
3131
_LOGGER: Final = logging.getLogger(__name__)
3232

3333

34-
def prove(opts: ProveOpts) -> APRProof:
35-
"""Run a proof creating a new KMIR instance."""
34+
def prove(opts: ProveOpts) -> list[APRProof]:
35+
"""Prove one or more start symbols from the same file, kompiling only once."""
3636
if not opts.rs_file.is_file():
3737
raise ValueError(f'Input file does not exist: {opts.rs_file}')
3838

3939
if opts.max_workers is not None and opts.max_workers < 1:
4040
raise ValueError(f'Expected positive integer for `max_workers, got: {opts.max_workers}')
4141

42-
label = f'{opts.rs_file.stem}.{opts.start_symbol}'
43-
4442
if opts.proof_dir is not None:
45-
target_path = opts.proof_dir / label
46-
return _prove(opts, target_path, label)
43+
if len(opts.start_symbols) == 1:
44+
label = f'{opts.rs_file.stem}.{opts.start_symbols[0]}'
45+
target_path = opts.proof_dir / label
46+
else:
47+
# Multiple start symbols share a separate target dir
48+
target_path = opts.proof_dir / f'{opts.rs_file.stem}.kompiled'
49+
return _prove_multi(opts, target_path)
4750

4851
with tempfile.TemporaryDirectory() as tmp_dir:
4952
target_path = Path(tmp_dir)
50-
return _prove(opts, target_path, label)
53+
return _prove_multi(opts, target_path)
5154

5255

5356
def prove_with_kmir(
5457
kmir: KMIR,
5558
smir_info: SMIRInfo,
5659
opts: ProveOpts,
5760
) -> APRProof:
58-
"""Run a proof using a pre-built KMIR instance, avoiding redundant kompilation."""
61+
"""Prove a single symbol using a pre-built KMIR instance.
62+
63+
The intended use case for this function is internal with the test runner.
64+
Use this instead of `prove()` when the caller manages kompilation externally
65+
(e.g. the test harness kompiles once and loops over symbols). Only the first
66+
entry in `opts.start_symbols` is used; for multi-symbol proving use `prove()`
67+
which kompiles once internally but is for external use via `kmir prove`.
68+
"""
69+
assert len(opts.start_symbols) == 1, f'prove_with_kmir handles a single symbol; got {len(opts.start_symbols)}.'
5970

6071
if opts.max_workers is not None and opts.max_workers < 1:
6172
raise ValueError(f'Expected positive integer for `max_workers, got: {opts.max_workers}')
6273

6374
# No check for rs_file as smir_info already exists
64-
label = f'{opts.rs_file.stem}.{opts.start_symbol}'
75+
start_symbol = opts.start_symbols[0]
76+
label = f'{opts.rs_file.stem}.{start_symbol}'
6577

6678
_LOGGER.info(f'Using pre-built KMIR for proof: {label}')
6779
proof = apr_proof_from_smir(
6880
kmir,
6981
label,
7082
smir_info,
71-
start_symbol=opts.start_symbol,
83+
start_symbol=start_symbol,
7284
proof_dir=opts.proof_dir,
7385
)
74-
if proof.proof_dir is not None and (proof.proof_dir / label).is_dir():
75-
smir_info.dump(proof.proof_dir / proof.id / 'smir.json')
86+
if opts.proof_dir is not None:
87+
kompiled_smir_path = opts.proof_dir / f'{opts.rs_file.stem}.kompiled' / 'smir.json'
88+
kompiled_smir_path.parent.mkdir(parents=True, exist_ok=True)
89+
smir_info.dump(kompiled_smir_path)
7690

7791
return _advance_proof(kmir, proof, opts, label)
7892

7993

80-
def _prove(opts: ProveOpts, target_path: Path, label: str) -> APRProof:
81-
if not opts.reload and opts.proof_dir is not None and APRProof.proof_data_exists(label, opts.proof_dir):
82-
_LOGGER.info(f'Reading proof from disc: {opts.proof_dir}, {label}')
83-
proof = APRProof.read_proof_data(opts.proof_dir, label)
84-
85-
smir_info = SMIRInfo.from_file(target_path / 'smir.json')
86-
kmir = KMIR.from_kompiled_kore(
87-
smir_info,
88-
target_dir=target_path,
89-
extra_module=opts.add_module,
90-
bug_report=opts.bug_report,
91-
symbolic=True,
92-
haskell_target=opts.haskell_target,
93-
llvm_lib_target=opts.llvm_lib_target,
94-
break_on_function=opts.break_on_function or None,
95-
)
94+
def _prove_multi(opts: ProveOpts, target_path: Path) -> list[APRProof]:
95+
"""Prove single or multiple symbols with a single kompilation."""
96+
labels = [f'{opts.rs_file.stem}.{sym}' for sym in opts.start_symbols]
97+
98+
if not labels:
99+
raise ValueError('No label to prove')
100+
101+
# Check which proofs can be resumed
102+
resumable: dict[str, APRProof] = {}
103+
if not opts.reload and opts.proof_dir is not None:
104+
for label in labels:
105+
if APRProof.proof_data_exists(label, opts.proof_dir):
106+
_LOGGER.info(f'Reading proof from disc: {opts.proof_dir}, {label}')
107+
resumable[label] = APRProof.read_proof_data(opts.proof_dir, label)
108+
109+
# Load SMIR info (once)
110+
kompiled_smir_path = target_path / 'smir.json'
111+
if len(resumable) == len(labels):
112+
# All proofs are resumed, load SMIR from saved data
113+
smir_info = SMIRInfo.from_file(kompiled_smir_path)
96114
else:
97-
_LOGGER.info(f'Constructing initial proof: {label}')
115+
# Need fresh SMIR for at least one proof
98116
if opts.parsed_smir is not None:
99117
smir_info = SMIRInfo(opts.parsed_smir)
100118
elif opts.smir:
101119
smir_info = SMIRInfo.from_file(opts.rs_file)
102120
else:
103121
smir_info = SMIRInfo(cargo_get_smir_json(opts.rs_file, save_smir=opts.save_smir))
104122

105-
smir_info = smir_info.reduce_to(opts.start_symbol)
123+
smir_info = smir_info.reduce_to(opts.start_symbols)
106124
# Report whether the reduced call graph includes any functions without MIR bodies
107125
missing_body_syms = [
108126
sym
@@ -114,28 +132,38 @@ def _prove(opts: ProveOpts, target_path: Path, label: str) -> APRProof:
114132
_LOGGER.info(f'missing-bodies-present={has_missing} count={len(missing_body_syms)}')
115133
_LOGGER.debug(f'Missing-body function symbols (first 5): {missing_body_syms[:5]}')
116134

117-
kmir = KMIR.from_kompiled_kore(
118-
smir_info,
119-
target_dir=target_path,
120-
extra_module=opts.add_module,
121-
bug_report=opts.bug_report,
122-
symbolic=True,
123-
haskell_target=opts.haskell_target,
124-
llvm_lib_target=opts.llvm_lib_target,
125-
break_on_function=opts.break_on_function or None,
126-
)
135+
kmir = KMIR.from_kompiled_kore(
136+
smir_info,
137+
target_dir=target_path,
138+
extra_module=opts.add_module,
139+
bug_report=opts.bug_report,
140+
symbolic=True,
141+
haskell_target=opts.haskell_target,
142+
llvm_lib_target=opts.llvm_lib_target,
143+
break_on_function=opts.break_on_function or None,
144+
)
127145

128-
proof = apr_proof_from_smir(
129-
kmir,
130-
label,
131-
smir_info,
132-
start_symbol=opts.start_symbol,
133-
proof_dir=opts.proof_dir,
134-
)
135-
if proof.proof_dir is not None and (proof.proof_dir / label).is_dir():
136-
smir_info.dump(proof.proof_dir / proof.id / 'smir.json')
146+
smir_info.dump(kompiled_smir_path)
137147

138-
return _advance_proof(kmir, proof, opts, label)
148+
# Prove each symbol sequentially using shared definition
149+
results: list[APRProof] = []
150+
for label, start_symbol in zip(labels, opts.start_symbols, strict=True):
151+
if label in resumable:
152+
proof = resumable[label]
153+
else:
154+
_LOGGER.info(f'Constructing initial proof: {label}')
155+
proof = apr_proof_from_smir(
156+
kmir,
157+
label,
158+
smir_info,
159+
start_symbol=start_symbol,
160+
proof_dir=opts.proof_dir,
161+
)
162+
163+
proof = _advance_proof(kmir, proof, opts, label)
164+
results.append(proof)
165+
166+
return results
139167

140168

141169
def _advance_proof(kmir: KMIR, proof: APRProof, opts: ProveOpts, label: str) -> APRProof:

kmir/src/kmir/kmir.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,17 @@ def run_smir(
121121
return result
122122

123123
@staticmethod
124-
def prove_program(opts: ProveOpts) -> APRProof:
124+
def prove_programs(opts: ProveOpts) -> list[APRProof]:
125125
from ._prove import prove
126126

127127
return prove(opts)
128128

129+
@staticmethod
130+
def prove_program(opts: ProveOpts) -> APRProof:
131+
proofs = KMIR.prove_programs(opts)
132+
assert len(proofs) == 1, f'Expected single proof, got {len(proofs)}'
133+
return proofs[0]
134+
129135
@staticmethod
130136
def prove_program_with_kmir(kmir: KMIR, smir_info: SMIRInfo, opts: ProveOpts) -> APRProof:
131137
from ._prove import prove_with_kmir
@@ -171,12 +177,20 @@ def __init__(self, cterm_show: CTermShow, proof: APRProof, opts: DisplayOpts) ->
171177
self.smir_info = None
172178
if opts.smir_info:
173179
self.smir_info = SMIRInfo.from_file(opts.smir_info)
174-
elif (
175-
proof.proof_dir is not None
176-
and (proof.proof_dir / proof.id).is_dir()
177-
and (proof.proof_dir / proof.id / 'smir.json').is_file()
178-
):
179-
self.smir_info = SMIRInfo.from_file(proof.proof_dir / proof.id / 'smir.json')
180+
elif proof.proof_dir is not None:
181+
file_stem = proof.id.rsplit('.', 1)[0]
182+
# Single-symbol: smir.json in the proof's own directory
183+
label_smir = proof.proof_dir / proof.id / 'smir.json'
184+
# Multi-symbol: smir.json in the shared kompiled directory
185+
kompiled_smir = proof.proof_dir / f'{file_stem}.kompiled' / 'smir.json'
186+
if label_smir.is_file():
187+
self.smir_info = SMIRInfo.from_file(label_smir)
188+
elif kompiled_smir.is_file():
189+
self.smir_info = SMIRInfo.from_file(kompiled_smir)
190+
else:
191+
_LOGGER.warning('SMIR info not found, span/function annotations unavailable')
192+
else:
193+
_LOGGER.warning('No SMIR Info or proof dir found, span/function annotations unavailable')
180194

181195
def _span(self, node: KCFG.Node) -> str | None:
182196
curr_span: int | None = None

kmir/src/kmir/options.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class ProveOpts(KMirOpts):
7575
save_smir: bool
7676
smir: bool
7777
parsed_smir: dict | None
78-
start_symbol: str
78+
start_symbols: list[str]
7979
add_module: Path | str | None
8080
break_on_calls: bool
8181
break_on_function_calls: bool
@@ -111,7 +111,7 @@ def __init__(
111111
save_smir: bool = False,
112112
smir: bool = False,
113113
parsed_smir: dict | None = None,
114-
start_symbol: str = 'main',
114+
start_symbols: list[str] | None = None,
115115
break_on_calls: bool = False,
116116
break_on_function_calls: bool = False,
117117
break_on_intrinsic_calls: bool = False,
@@ -144,7 +144,12 @@ def __init__(
144144
self.save_smir = save_smir
145145
self.smir = smir
146146
self.parsed_smir = parsed_smir
147-
self.start_symbol = start_symbol
147+
148+
if start_symbols is not None:
149+
self.start_symbols = start_symbols
150+
else:
151+
self.start_symbols = ['main']
152+
148153
self.break_on_calls = break_on_calls
149154
self.break_on_function_calls = break_on_function_calls
150155
self.break_on_intrinsic_calls = break_on_intrinsic_calls

0 commit comments

Comments
 (0)