Skip to content

Commit 6c8d46e

Browse files
authored
More enum decoding (#740)
* Decode single-variant enums * Decode multi-variant enums with direct tag encoding
1 parent 5432828 commit 6c8d46e

12 files changed

Lines changed: 612 additions & 38 deletions

kmir/src/kmir/decoding.py

Lines changed: 167 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,23 @@
77
from pyk.kast.prelude.string import stringToken
88

99
from .alloc import Allocation, AllocInfo, Memory, ProvenanceEntry, ProvenanceMap
10-
from .ty import ArrayT, Bool, EnumT, Int, IntTy, PtrT, RefT, Str, Uint
10+
from .ty import (
11+
ArbitraryFields,
12+
ArrayT,
13+
BoolT,
14+
Direct,
15+
EnumT,
16+
Initialized,
17+
IntT,
18+
IntTy,
19+
Multiple,
20+
PrimitiveInt,
21+
PtrT,
22+
RefT,
23+
Single,
24+
StrT,
25+
UintT,
26+
)
1127
from .value import (
1228
NO_METADATA,
1329
AggregateValue,
@@ -26,7 +42,7 @@
2642

2743
from pyk.kast import KInner
2844

29-
from .ty import Ty, TypeMetadata, UintTy
45+
from .ty import FieldsShape, LayoutShape, MachineSize, Scalar, TagEncoding, Ty, TypeMetadata, UintTy
3046
from .value import Metadata
3147

3248

@@ -126,16 +142,26 @@ def decode_value_or_unable(data: bytes, type_info: TypeMetadata, types: Mapping[
126142

127143
def decode_value(data: bytes, type_info: TypeMetadata, types: Mapping[Ty, TypeMetadata]) -> Value:
128144
match type_info:
129-
case Bool():
145+
case BoolT():
130146
return _decode_bool(data)
131-
case Str():
147+
case StrT():
132148
return _decode_str(data)
133-
case Uint(int_ty) | Int(int_ty):
149+
case UintT(int_ty) | IntT(int_ty):
134150
return _decode_int(data, int_ty)
135151
case ArrayT(elem_ty, length):
136152
return _decode_array(data, elem_ty, length, types)
137-
case EnumT(discriminants=discriminants, fields=fields):
138-
return _decode_enum(data, discriminants, fields)
153+
case EnumT(
154+
discriminants=discriminants,
155+
fields=fields,
156+
layout=layout,
157+
):
158+
return _decode_enum(
159+
data=data,
160+
discriminants=discriminants,
161+
fields=fields,
162+
layout=layout,
163+
types=types,
164+
)
139165
case _:
140166
raise ValueError(f'Unsupported type: {type_info}')
141167

@@ -195,18 +221,145 @@ def _decode_array(
195221

196222

197223
def _decode_enum(
224+
*,
225+
data: bytes,
226+
discriminants: list[int],
227+
fields: list[list[Ty]],
228+
layout: LayoutShape | None,
229+
types: Mapping[Ty, TypeMetadata],
230+
) -> Value:
231+
if not layout:
232+
raise ValueError('Enum layout not provided')
233+
234+
offsets = _extract_offsets(layout.fields)
235+
236+
match layout.variants:
237+
case Single(index):
238+
return _decode_enum_single(
239+
data=data,
240+
discriminants=discriminants,
241+
fields=fields,
242+
offsets=offsets,
243+
# ---
244+
tag_index=index,
245+
# ---
246+
types=types,
247+
)
248+
case Multiple(
249+
tag=tag,
250+
tag_encoding=tag_encoding,
251+
tag_field=tag_field,
252+
variants=variants,
253+
):
254+
return _decode_enum_multiple(
255+
data=data,
256+
discriminants=discriminants,
257+
fields=fields,
258+
offsets=offsets,
259+
# ---
260+
tag=tag,
261+
tag_encoding=tag_encoding,
262+
tag_field=tag_field,
263+
variant_layouts=variants,
264+
# ---
265+
types=types,
266+
)
267+
case _:
268+
raise AssertionError('Undhandled case')
269+
270+
271+
def _extract_offsets(fields_shape: FieldsShape) -> list[MachineSize]:
272+
match fields_shape:
273+
case ArbitraryFields(offsets=offsets):
274+
return offsets
275+
case _:
276+
raise ValueError(f'Unsupported fields shape: {fields_shape}')
277+
278+
279+
def _decode_enum_single(
280+
*,
198281
data: bytes,
199282
discriminants: list[int],
200283
fields: list[list[Ty]],
284+
offsets: list[MachineSize],
285+
tag_index: int,
286+
types: Mapping[Ty, TypeMetadata],
201287
) -> Value:
202-
# The only supported case for now is when there are no fields
203-
if any(tys for tys in fields):
204-
raise ValueError('TODO - implement this case')
288+
assert len(fields) == 1, 'Expected a single list of field types for single-variant enum'
289+
tys = fields[0]
290+
291+
assert len(discriminants) == 1, 'Expected a single discriminant for single-variant enum'
292+
discriminant = discriminants[0]
293+
assert tag_index == discriminant, 'Assumed tag_index to be the same as the discriminant'
294+
295+
field_values = _decode_fields(data=data, tys=tys, offsets=offsets, types=types)
296+
return AggregateValue(0, field_values)
297+
298+
299+
def _decode_enum_multiple(
300+
*,
301+
data: bytes,
302+
discriminants: list[int],
303+
fields: list[list[Ty]],
304+
offsets: list[MachineSize],
305+
# ---
306+
tag: Scalar,
307+
tag_encoding: TagEncoding,
308+
tag_field: int,
309+
variant_layouts: list[LayoutShape],
310+
# ---
311+
types: Mapping[Ty, TypeMetadata],
312+
) -> Value:
313+
if not isinstance(tag_encoding, Direct):
314+
raise ValueError(f'Unsupported encoding: {tag_encoding}')
315+
316+
assert tag_field == 0, 'Assumed tag field to be zero'
317+
assert len(offsets) == 1, 'Assumed offsets to only contain the tag offset'
318+
tag_offset = offsets[0]
319+
tag_value = _extract_tag_value(data=data, tag_offset=tag_offset, tag=tag)
205320

206-
tag = int.from_bytes(data, byteorder='little', signed=False)
207321
try:
208-
variant_idx = discriminants.index(tag)
322+
variant_idx = discriminants.index(tag_value)
209323
except ValueError as err:
210-
raise ValueError(f'Tag not found: {tag}') from err
324+
raise ValueError(f'Tag not found: {tag_value}') from err
325+
326+
tys = fields[variant_idx]
327+
328+
variant_layout = variant_layouts[variant_idx]
329+
field_offsets = _extract_offsets(variant_layout.fields)
330+
assert isinstance(variant_layout.variants, Single)
211331

212-
return AggregateValue(variant_idx, ())
332+
field_values = _decode_fields(data=data, tys=tys, offsets=field_offsets, types=types)
333+
return AggregateValue(variant_idx, field_values)
334+
335+
336+
def _decode_fields(
337+
*,
338+
data: bytes,
339+
tys: list[Ty],
340+
offsets: list[MachineSize],
341+
types: Mapping[Ty, TypeMetadata],
342+
) -> list[Value]:
343+
res: list[Value] = []
344+
for ty, offset in zip(tys, offsets, strict=True):
345+
type_info = types[ty]
346+
size_in_bytes = type_info.nbytes(types)
347+
field_data = data[offset.in_bytes : offset.in_bytes + size_in_bytes]
348+
value = decode_value(field_data, type_info, types)
349+
res.append(value)
350+
return res
351+
352+
353+
def _extract_tag_value(*, data: bytes, tag_offset: MachineSize, tag: Scalar) -> int:
354+
match tag:
355+
case Initialized(
356+
value=PrimitiveInt(
357+
length=length,
358+
signed=signed,
359+
),
360+
valid_range=_,
361+
):
362+
tag_data = data[tag_offset.in_bytes : tag_offset.in_bytes + length.value]
363+
return int.from_bytes(tag_data, byteorder='little', signed=signed)
364+
case _:
365+
raise ValueError('Unsupported tag: {tag}')

kmir/src/kmir/kast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pyk.kast.prelude.ml import mlEqualsTrue
1010
from pyk.kast.prelude.utils import token
1111

12-
from .ty import ArrayT, Bool, EnumT, Int, PtrT, RefT, StructT, TupleT, Uint, UnionT
12+
from .ty import ArrayT, BoolT, EnumT, IntT, PtrT, RefT, StructT, TupleT, UintT, UnionT
1313

1414
if TYPE_CHECKING:
1515
from collections.abc import Iterable
@@ -146,15 +146,15 @@ def _fresh_var(self, prefix: str) -> KVariable:
146146
def _symbolic_value(self, ty: Ty, mutable: bool) -> tuple[KInner, Iterable[KInner], KInner | None]:
147147
# returns: symbolic value of given type, related constraints, related pointer metadata
148148
match self.smir_info.types.get(ty):
149-
case Int(info):
149+
case IntT(info):
150150
val, constraints = int_var(self._fresh_var('ARG_INT'), info.value, True)
151151
return val, constraints, None
152152

153-
case Uint(info):
153+
case UintT(info):
154154
val, constraints = int_var(self._fresh_var('ARG_UINT'), info.value, False)
155155
return val, constraints, None
156156

157-
case Bool():
157+
case BoolT():
158158
val, constraints = bool_var(self._fresh_var('ARG_BOOL'))
159159
return val, constraints, None
160160

0 commit comments

Comments
 (0)