Skip to content

Commit dcc97b0

Browse files
authored
Added decoding for static references (#1091)
Added decoding for static references: - Changed `reduce_to` to keep the static items - Decoding of statics added to `decoding.py` - Added passing test for static references, and failing test for static pointers (to be followed up on in another PR)
1 parent 122030f commit dcc97b0

11 files changed

Lines changed: 3112 additions & 17 deletions

kmir/src/kmir/decoding.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
from pyk.kast import KInner
4949

50+
from .alloc import DefId
5051
from .ty import FieldsShape, LayoutShape, MachineSize, Scalar, TagEncoding, Ty, TypeMetadata, UintTy
5152
from .value import MetadataSize
5253

@@ -62,7 +63,12 @@ def to_kast(self) -> KInner:
6263
)
6364

6465

65-
def decode_alloc_or_unable(alloc_info: AllocInfo, types: Mapping[Ty, TypeMetadata]) -> Value:
66+
def decode_alloc_or_unable(
67+
alloc_info: AllocInfo,
68+
types: Mapping[Ty, TypeMetadata],
69+
statics: Mapping[DefId, Allocation] | None = None,
70+
) -> Value:
71+
statics = statics or {}
6672
match alloc_info:
6773
case AllocInfo(
6874
ty=ty,
@@ -78,12 +84,28 @@ def decode_alloc_or_unable(alloc_info: AllocInfo, types: Mapping[Ty, TypeMetadat
7884
data = bytes(n or 0 for n in bytez)
7985
return _decode_memory_alloc_or_unable(data=data, ptrs=ptrs, ty=ty, types=types)
8086
case AllocInfo(
81-
ty=_,
82-
# `Static` currently only carries `def_id`; we ignore it here.
83-
global_alloc=Static(),
87+
ty=ty,
88+
global_alloc=Static(def_id=def_id),
8489
):
85-
# Static global alloc does not carry raw bytes here; leave as unable-to-decode placeholder
86-
return UnableToDecodeValue('Static global allocation not decoded')
90+
# The static's bytes live in its `MonoItemStatic` item, not in the alloc entry.
91+
# The alloc's `ty` is a reference to the static; decode its contents against the pointee type.
92+
allocation = statics.get(def_id)
93+
if allocation is None:
94+
return UnableToDecodeValue(f'Static allocation not found for def_id: {def_id}')
95+
96+
try:
97+
type_info = types[ty]
98+
except KeyError:
99+
return UnableToDecodeValue(f'Decoding static allocation with unknown type: {ty}')
100+
101+
pointee_ty = _pointee_ty(type_info)
102+
if pointee_ty is None:
103+
return UnableToDecodeValue(f'Static allocation type is not a reference or a pointer: {type_info}')
104+
105+
data = bytes(n or 0 for n in allocation.bytez)
106+
return _decode_memory_alloc_or_unable(
107+
data=data, ptrs=allocation.provenance.ptrs, ty=pointee_ty, types=types
108+
)
87109
case AllocInfo(
88110
ty=_,
89111
global_alloc=Function(

kmir/src/kmir/kompile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def _decode_alloc(smir_info: SMIRInfo, raw_alloc: Any) -> tuple[KInner, KInner]:
662662

663663
alloc_id = raw_alloc['alloc_id']
664664
alloc_info = smir_info.allocs[alloc_id]
665-
value = decode_alloc_or_unable(alloc_info=alloc_info, types=smir_info.types)
665+
value = decode_alloc_or_unable(alloc_info=alloc_info, types=smir_info.types, statics=smir_info.statics)
666666

667667
match value:
668668
case UnableToDecodeValue(msg):

kmir/src/kmir/smir.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from functools import cached_property
77
from typing import TYPE_CHECKING, NewType
88

9-
from .alloc import AllocInfo
9+
from .alloc import Allocation, AllocInfo, DefId
1010
from .ty import EnumT, RefT, StructT, Ty, TypeMetadata, UnionT
1111

1212
if TYPE_CHECKING:
@@ -54,6 +54,16 @@ def allocs(self) -> dict[AllocId, AllocInfo]:
5454
alloc_info.alloc_id: alloc_info for alloc_info in (AllocInfo.from_dict(dct) for dct in self._smir['allocs'])
5555
}
5656

57+
@cached_property
58+
def statics(self) -> dict[DefId, Allocation]:
59+
res: dict[DefId, Allocation] = {}
60+
for item in self._smir['items']:
61+
kind = item['mono_item_kind']
62+
if 'MonoItemStatic' in kind:
63+
mono_item_static = kind['MonoItemStatic']
64+
res[DefId(mono_item_static['id'])] = Allocation.from_dict(mono_item_static['allocation'])
65+
return res
66+
5767
@cached_property
5868
def types(self) -> dict[Ty, TypeMetadata]:
5969
return {Ty(id): TypeMetadata.from_raw(type) for id, type in self._smir['types']}
@@ -131,7 +141,11 @@ def function_symbols(self) -> dict[int, dict]:
131141
fnc_symbols[-1] = {'NormalSym': self.main_symbol}
132142

133143
# function items not present in the SMIR lookup table are added with negative Ty ID
134-
missing = [name for name in self.items.keys() if {'NormalSym': name} not in fnc_symbols.values()]
144+
missing = [
145+
name
146+
for name, item in self.items.items()
147+
if SMIRInfo._is_func(item) and {'NormalSym': name} not in fnc_symbols.values()
148+
]
135149

136150
fake_ty = -2
137151
for name in missing:
@@ -181,6 +195,10 @@ def spans(self) -> dict[int, tuple[Path, int, int, int, int]]:
181195
def _is_func(item: dict[str, dict]) -> bool:
182196
return 'MonoItemFn' in item['mono_item_kind']
183197

198+
@staticmethod
199+
def _is_static(item: dict[str, dict]) -> bool:
200+
return 'MonoItemStatic' in item['mono_item_kind']
201+
184202
def reduce_to(self, start_symbols: str | Sequence[str]) -> SMIRInfo:
185203
# returns a new SMIRInfo with all _items_ removed that are not reachable from the named function(s)
186204
match start_symbols:
@@ -199,10 +217,11 @@ def reduce_to(self, start_symbols: str | Sequence[str]) -> SMIRInfo:
199217

200218
new_smir = self._smir.copy() # shallow copy, but we can overwrite the `items`
201219

202-
# filter the new symbols to avoid key errors
220+
# filter the new function symbols to avoid key errors
203221
new_syms = [self.function_symbols[ty] for ty in reachable]
204-
new_syms_ = [sym['NormalSym'] for sym in new_syms if 'NormalSym' in sym]
205-
new_smir['items'] = [self.items[sym] for sym in new_syms_ if sym in self.items]
222+
new_syms_ = {sym['NormalSym'] for sym in new_syms if 'NormalSym' in sym}
223+
# Also keep the statics
224+
new_smir['items'] = [item for sym, item in self.items.items() if SMIRInfo._is_static(item) or sym in new_syms_]
206225

207226
return SMIRInfo(new_smir)
208227

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
static NUM: u8 = 55;
2+
3+
fn main() {
4+
let num_ref = #
5+
assert!(*num_ref == 55);
6+
}

0 commit comments

Comments
 (0)