Skip to content

Commit b8e6a93

Browse files
authored
Implement decoding for structs (#762)
Closes #761
1 parent 161c1a6 commit b8e6a93

5 files changed

Lines changed: 131 additions & 9 deletions

File tree

kmir/src/kmir/decoding.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
RefT,
2222
Single,
2323
StrT,
24+
StructT,
2425
UintT,
2526
)
2627
from .value import (
@@ -149,6 +150,8 @@ def decode_value(data: bytes, type_info: TypeMetadata, types: Mapping[Ty, TypeMe
149150
return _decode_int(data, int_ty)
150151
case ArrayT(elem_ty, length):
151152
return _decode_array(data, elem_ty, length, types)
153+
case StructT(fields=fields, layout=layout):
154+
return _decode_struct(data=data, fields=fields, layout=layout, types=types)
152155
case EnumT(
153156
discriminants=discriminants,
154157
fields=fields,
@@ -219,6 +222,28 @@ def _decode_array(
219222
return RangeValue(elems)
220223

221224

225+
def _decode_struct(
226+
*,
227+
data: bytes,
228+
fields: list[Ty],
229+
layout: LayoutShape | None,
230+
types: Mapping[Ty, TypeMetadata],
231+
) -> Value:
232+
if not layout:
233+
raise ValueError('Struct layout not provided')
234+
235+
offsets = _extract_offsets(layout.fields)
236+
237+
match layout.variants:
238+
case Single(index=0):
239+
pass
240+
case _:
241+
raise ValueError(f'Unexpected layout variants in struct: {layout.variants}')
242+
243+
field_values = _decode_fields(data=data, tys=fields, offsets=offsets, types=types)
244+
return AggregateValue(0, field_values)
245+
246+
222247
def _decode_enum(
223248
*,
224249
data: bytes,

kmir/src/kmir/ty.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def from_raw(data: Any) -> EnumT:
146146
adt_def=adt_def,
147147
discriminants=list(discriminants),
148148
fields=[list(tys) for tys in fields],
149-
layout=LayoutShape.from_raw(layout),
149+
layout=LayoutShape.from_raw(layout) if layout is not None else None,
150150
)
151151
case _:
152152
raise _cannot_parse_as('EnumT', data)
@@ -467,6 +467,7 @@ class StructT(TypeMetadata):
467467
name: str
468468
adt_def: int
469469
fields: list[Ty]
470+
layout: LayoutShape | None
470471

471472
@staticmethod
472473
def from_raw(data: Any) -> StructT:
@@ -476,16 +477,25 @@ def from_raw(data: Any) -> StructT:
476477
'name': name,
477478
'adt_def': adt_def,
478479
'fields': fields,
480+
'layout': layout,
479481
}
480482
}:
481483
return StructT(
482484
name=name,
483485
adt_def=adt_def,
484486
fields=list(fields),
487+
layout=LayoutShape.from_raw(layout) if layout is not None else None,
485488
)
486489
case _:
487490
raise _cannot_parse_as('StructT', data)
488491

492+
def nbytes(self, types: Mapping[Ty, TypeMetadata]) -> int:
493+
match self.layout:
494+
case None:
495+
raise ValueError(f'Cannot determine size, layout is missing for: {self}')
496+
case LayoutShape(size=size):
497+
return size.in_bytes
498+
489499

490500
@dataclass
491501
class UnionT(TypeMetadata):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Aggregate ( variantIdx ( 0 ) , ListItem ( Integer ( 15 , 8 , false ) )
2+
ListItem ( Integer ( 31 , 8 , false ) )
3+
ListItem ( BoolVal ( true ) ) )
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
{
2+
"bytes": [
3+
1,
4+
15,
5+
31
6+
],
7+
"types": [
8+
[
9+
0,
10+
{
11+
"PrimitiveType": {
12+
"Uint": "U8"
13+
}
14+
}
15+
],
16+
[
17+
1,
18+
{
19+
"PrimitiveType": "Bool"
20+
}
21+
]
22+
],
23+
"typeInfo": {
24+
"StructType": {
25+
"name": "core::ops::RangeInclusive<u8>",
26+
"adt_def": 100,
27+
"fields": [
28+
0,
29+
0,
30+
1
31+
],
32+
"layout": {
33+
"fields": {
34+
"Arbitrary": {
35+
"offsets": [
36+
{
37+
"num_bits": 8
38+
},
39+
{
40+
"num_bits": 16
41+
},
42+
{
43+
"num_bits": 0
44+
}
45+
]
46+
}
47+
},
48+
"variants": {
49+
"Single": {
50+
"index": 0
51+
}
52+
},
53+
"abi": {
54+
"Aggregate": {
55+
"sized": true
56+
}
57+
},
58+
"abi_align": 1,
59+
"size": {
60+
"num_bits": 24
61+
}
62+
}
63+
}
64+
}
65+
}

kmir/src/tests/integration/test_decode_value.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,16 @@ def definition(definition_dir: Path) -> KDefinition:
5252
from pyk.kast.outer import read_kast_definition
5353

5454
res = read_kast_definition(definition_dir / 'compiled.json')
55+
_patch_definition(res)
56+
return res
57+
5558

59+
def _patch_definition(definition: KDefinition) -> None:
5660
# Monkey patch __repr__ on the fixture to avoid flooding the output on test failure
57-
cls = res.__class__
61+
cls = definition.__class__
5862
new_repr = lambda self: repr('KMIR LLVM definition')
5963
new_cls = type(f'{cls.__name__}WithCustomRepr', (cls,), {'__repr__': new_repr})
60-
object.__setattr__(res, '__class__', new_cls)
61-
62-
return res
64+
object.__setattr__(definition, '__class__', new_cls)
6365

6466

6567
def dedent(s: str) -> str:
@@ -181,6 +183,7 @@ def load_test_types():
181183
'enum-option-nonzero-none',
182184
'enum-option-nonzero-some',
183185
'str',
186+
'struct-simple-permuted-fields',
184187
)
185188

186189

@@ -232,15 +235,31 @@ def test_decode_value(
232235
assert test_data.expected == actual
233236

234237

238+
@pytest.fixture(scope='module')
239+
def kmir_definition_dir() -> Path:
240+
from kmir.build import LLVM_DEF_DIR
241+
242+
return LLVM_DEF_DIR
243+
244+
245+
@pytest.fixture(scope='module')
246+
def kmir_definition(kmir_definition_dir: Path) -> KDefinition:
247+
from pyk.kast.outer import read_kast_definition
248+
249+
res = read_kast_definition(kmir_definition_dir / 'compiled.json')
250+
_patch_definition(res)
251+
return res
252+
253+
235254
@pytest.mark.parametrize(
236255
'test_data',
237256
TEST_DATA,
238257
ids=[test_id for test_id, *_ in TEST_DATA],
239258
)
240259
def test_python_decode_value(
241260
test_data: _TestData,
242-
definition_dir: Path,
243-
definition: KDefinition,
261+
kmir_definition_dir: Path,
262+
kmir_definition: KDefinition,
244263
tmp_path: Path,
245264
) -> None:
246265
from pyk.kast.inner import KSort
@@ -260,9 +279,9 @@ def test_python_decode_value(
260279
types=types,
261280
)
262281
kast = value.to_kast()
263-
kore = kast_to_kore(definition, kast, KSort('Value'))
282+
kore = kast_to_kore(kmir_definition, kast, KSort('Value'))
264283
actual = kore_print(
265-
definition_dir=definition_dir,
284+
definition_dir=kmir_definition_dir,
266285
pattern=kore,
267286
output='pretty',
268287
)

0 commit comments

Comments
 (0)