Skip to content

Commit 443f9ad

Browse files
sskeirikStevengreclaude
authored
Teach kmir to reduce SMIR K definitions over a set of CFG roots (#845)
Co-authored-by: Stevengre <zhaojianhong96@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9af0261 commit 443f9ad

6 files changed

Lines changed: 110 additions & 9 deletions

File tree

kmir/src/kmir/__main__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
LinkOpts,
2424
ProveOpts,
2525
PruneOpts,
26+
ReduceOpts,
2627
RunOpts,
2728
SectionEdgeOpts,
2829
ShowOpts,
@@ -251,6 +252,17 @@ def _kmir_link(opts: LinkOpts) -> None:
251252
result.dump(opts.output_file)
252253

253254

255+
def _kmir_reduce(opts: ReduceOpts) -> None:
256+
smir_info = SMIRInfo.from_file(opts.smir_file)
257+
original = len(smir_info.items)
258+
reduced = smir_info.reduce_to(opts.roots)
259+
reduced.dump(opts.output_file)
260+
_LOGGER.info(
261+
f'Reduced {original} -> {len(reduced.items)} items'
262+
f' ({original - len(reduced.items)} pruned), written to {opts.output_file}'
263+
)
264+
265+
254266
def kmir(args: Sequence[str]) -> None:
255267
ns = _arg_parser().parse_args(args)
256268
opts = _parse_args(ns)
@@ -272,6 +284,8 @@ def kmir(args: Sequence[str]) -> None:
272284
_kmir_prove(opts)
273285
case LinkOpts():
274286
_kmir_link(opts)
287+
case ReduceOpts():
288+
_kmir_reduce(opts)
275289
case _:
276290
raise AssertionError()
277291

@@ -575,6 +589,27 @@ def _arg_parser() -> ArgumentParser:
575589
default='linker_output.smir.json',
576590
)
577591

592+
reduce_parser = command_parser.add_parser(
593+
'reduce',
594+
help='Reduce SMIR to functions reachable from given roots',
595+
parents=[kcli_args.logging_args],
596+
)
597+
reduce_parser.add_argument('smir_file', metavar='SMIR_JSON', help='SMIR JSON file to reduce')
598+
reduce_parser.add_argument(
599+
'--roots',
600+
'-r',
601+
required=True,
602+
metavar='ROOTS',
603+
help='Comma-separated root function names, or @file for newline-separated file',
604+
)
605+
reduce_parser.add_argument(
606+
'--output-file',
607+
'-o',
608+
metavar='FILE',
609+
help='Output file (default: reduced.smir.json)',
610+
default='reduced.smir.json',
611+
)
612+
578613
return parser
579614

580615

@@ -677,6 +712,12 @@ def _parse_args(ns: Namespace) -> KMirOpts:
677712
smir_files=ns.smir_files,
678713
output_file=ns.output_file,
679714
)
715+
case 'reduce':
716+
return ReduceOpts(
717+
smir_file=ns.smir_file,
718+
roots=ns.roots,
719+
output_file=ns.output_file,
720+
)
680721
case _:
681722
raise AssertionError()
682723

kmir/src/kmir/_prove.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def _prove(opts: ProveOpts, target_path: Path, label: str) -> APRProof:
8282
if 'MonoItemFn' in item['mono_item_kind'] and item['mono_item_kind']['MonoItemFn'].get('body') is None
8383
]
8484
has_missing = len(missing_body_syms) > 0
85-
_LOGGER.info(f'Reduced items table size {len(smir_info.items)}')
8685
if has_missing:
8786
_LOGGER.info(f'missing-bodies-present={has_missing} count={len(missing_body_syms)}')
8887
_LOGGER.debug(f'Missing-body function symbols (first 5): {missing_body_syms[:5]}')

kmir/src/kmir/options.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,20 @@ class LinkOpts(KMirOpts):
323323
def __init__(self, smir_files: list[str], output_file: str | None = None) -> None:
324324
self.smir_files = [Path(f) for f in smir_files]
325325
self.output_file = Path(output_file) if output_file is not None else Path('linker_output.smir.json')
326+
327+
328+
@dataclass
329+
class ReduceOpts(KMirOpts):
330+
smir_file: Path
331+
output_file: Path
332+
roots: list[str]
333+
334+
def __init__(self, smir_file: str, roots: str, output_file: str | None = None) -> None:
335+
self.smir_file = Path(smir_file)
336+
self.output_file = Path(output_file) if output_file is not None else Path('reduced.smir.json')
337+
# Support @file syntax for reading roots from a file
338+
if roots.startswith('@'):
339+
roots_file = Path(roots[1:])
340+
self.roots = list(filter(None, [r.strip() for r in roots_file.read_text().splitlines()]))
341+
else:
342+
self.roots = [r.strip() for r in roots.split(',') if r.strip()]

kmir/src/kmir/smir.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .ty import EnumT, RefT, StructT, Ty, TypeMetadata, UnionT
1111

1212
if TYPE_CHECKING:
13+
from collections.abc import Sequence
1314
from pathlib import Path
1415
from typing import Final
1516

@@ -180,13 +181,19 @@ def spans(self) -> dict[int, tuple[Path, int, int, int, int]]:
180181
def _is_func(item: dict[str, dict]) -> bool:
181182
return 'MonoItemFn' in item['mono_item_kind']
182183

183-
def reduce_to(self, start_name: str) -> SMIRInfo:
184-
# returns a new SMIRInfo with all _items_ removed that are not reachable from the named function
185-
start_ty = self.function_tys[start_name]
184+
def reduce_to(self, start_symbols: str | Sequence[str]) -> SMIRInfo:
185+
# returns a new SMIRInfo with all _items_ removed that are not reachable from the named function(s)
186+
match start_symbols:
187+
case str(symbol):
188+
start_tys = [Ty(self.function_tys[symbol])]
189+
case [*symbols] if symbols and all(isinstance(sym, str) for sym in symbols):
190+
start_tys = [Ty(self.function_tys[sym]) for sym in symbols]
191+
case _:
192+
raise ValueError('SMIRInfo.reduce_to() received an invalid start_symbol')
186193

187-
_LOGGER.debug(f'Reducing items, starting at {start_ty}. Call Edges {self.call_edges}')
194+
_LOGGER.debug(f'Reducing items, starting at {start_tys}. Call Edges {self.call_edges}')
188195

189-
reachable = compute_closure(Ty(start_ty), self.call_edges)
196+
reachable = compute_closure(start_tys, self.call_edges)
190197

191198
_LOGGER.debug(f'Reducing to reachable Tys {reachable}')
192199

@@ -249,8 +256,8 @@ def call_edges(self) -> dict[Ty, set[Ty]]:
249256
return result
250257

251258

252-
def compute_closure(start: Ty, edges: dict[Ty, set[Ty]]) -> set[Ty]:
253-
work = deque([start])
259+
def compute_closure(start_nodes: Sequence[Ty], edges: dict[Ty, set[Ty]]) -> set[Ty]:
260+
work = deque(start_nodes)
254261
reached = set()
255262
finished = False
256263
while not finished:

kmir/src/tests/integration/test_integration.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,3 +656,40 @@ def test_schema_kapply_parse(
656656
json_data, expected_term, expected_sort = test_case
657657

658658
assert parser.parse_mir_json(json_data, expected_sort.name) == (expected_term, expected_sort)
659+
660+
661+
ARITH_SMIR = PROVE_DIR / 'arith.smir.json'
662+
663+
664+
def test_reduce_standalone() -> None:
665+
"""Test that kmir reduce correctly prunes SMIR items by reachability."""
666+
smir_data = json.loads(ARITH_SMIR.read_text())
667+
info = SMIRInfo(smir_data)
668+
assert len(info.items) == 11
669+
670+
# Single root 'add' — should keep 1 item
671+
reduced_add = info.reduce_to('add')
672+
assert len(reduced_add.items) == 1
673+
674+
# Single root 'mul' — should keep 1 item (independent from add)
675+
reduced_mul = info.reduce_to('mul')
676+
assert len(reduced_mul.items) == 1
677+
678+
# Multiple roots — should keep strictly more than either alone
679+
reduced_multi = info.reduce_to(['add', 'mul'])
680+
assert len(reduced_multi.items) == 2
681+
682+
# 'main' calls both add and mul — should keep all 3
683+
reduced_main = info.reduce_to('main')
684+
assert len(reduced_main.items) == 3
685+
686+
# Roundtrip: save reduced SMIR and reload it
687+
with tempfile.NamedTemporaryFile(suffix='.smir.json', delete=False, mode='w') as f:
688+
f.write(json.dumps(reduced_multi._smir))
689+
reduced_path = Path(f.name)
690+
691+
try:
692+
reloaded = SMIRInfo(json.loads(reduced_path.read_text()))
693+
assert len(reloaded.items) == 2
694+
finally:
695+
reduced_path.unlink()

kmir/src/tests/unit/test_unit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,5 @@
5959
def test_compute_closure(test_case: tuple[str, Ty, dict[Ty, set[Ty]], list[Ty]]) -> None:
6060
_, start, edges, expected = test_case
6161

62-
result = compute_closure(start, edges)
62+
result = compute_closure([start], edges)
6363
assert result == expected

0 commit comments

Comments
 (0)