|
7 | 7 | from pyk.kast.prelude.string import stringToken |
8 | 8 |
|
9 | 9 | from .alloc import Allocation, AllocInfo, Memory, ProvenanceEntry, ProvenanceMap |
10 | | -from .ty import ArrayT, Bool, EnumT, Int, IntTy, PtrT, RefT, Str, Uint |
| 10 | +from .ty import ( |
| 11 | + ArbitraryFields, |
| 12 | + ArrayT, |
| 13 | + BoolT, |
| 14 | + Direct, |
| 15 | + EnumT, |
| 16 | + Initialized, |
| 17 | + IntT, |
| 18 | + IntTy, |
| 19 | + Multiple, |
| 20 | + PrimitiveInt, |
| 21 | + PtrT, |
| 22 | + RefT, |
| 23 | + Single, |
| 24 | + StrT, |
| 25 | + UintT, |
| 26 | +) |
11 | 27 | from .value import ( |
12 | 28 | NO_METADATA, |
13 | 29 | AggregateValue, |
|
26 | 42 |
|
27 | 43 | from pyk.kast import KInner |
28 | 44 |
|
29 | | - from .ty import Ty, TypeMetadata, UintTy |
| 45 | + from .ty import FieldsShape, LayoutShape, MachineSize, Scalar, TagEncoding, Ty, TypeMetadata, UintTy |
30 | 46 | from .value import Metadata |
31 | 47 |
|
32 | 48 |
|
@@ -126,16 +142,26 @@ def decode_value_or_unable(data: bytes, type_info: TypeMetadata, types: Mapping[ |
126 | 142 |
|
127 | 143 | def decode_value(data: bytes, type_info: TypeMetadata, types: Mapping[Ty, TypeMetadata]) -> Value: |
128 | 144 | match type_info: |
129 | | - case Bool(): |
| 145 | + case BoolT(): |
130 | 146 | return _decode_bool(data) |
131 | | - case Str(): |
| 147 | + case StrT(): |
132 | 148 | return _decode_str(data) |
133 | | - case Uint(int_ty) | Int(int_ty): |
| 149 | + case UintT(int_ty) | IntT(int_ty): |
134 | 150 | return _decode_int(data, int_ty) |
135 | 151 | case ArrayT(elem_ty, length): |
136 | 152 | return _decode_array(data, elem_ty, length, types) |
137 | | - case EnumT(discriminants=discriminants, fields=fields): |
138 | | - return _decode_enum(data, discriminants, fields) |
| 153 | + case EnumT( |
| 154 | + discriminants=discriminants, |
| 155 | + fields=fields, |
| 156 | + layout=layout, |
| 157 | + ): |
| 158 | + return _decode_enum( |
| 159 | + data=data, |
| 160 | + discriminants=discriminants, |
| 161 | + fields=fields, |
| 162 | + layout=layout, |
| 163 | + types=types, |
| 164 | + ) |
139 | 165 | case _: |
140 | 166 | raise ValueError(f'Unsupported type: {type_info}') |
141 | 167 |
|
@@ -195,18 +221,145 @@ def _decode_array( |
195 | 221 |
|
196 | 222 |
|
197 | 223 | def _decode_enum( |
| 224 | + *, |
| 225 | + data: bytes, |
| 226 | + discriminants: list[int], |
| 227 | + fields: list[list[Ty]], |
| 228 | + layout: LayoutShape | None, |
| 229 | + types: Mapping[Ty, TypeMetadata], |
| 230 | +) -> Value: |
| 231 | + if not layout: |
| 232 | + raise ValueError('Enum layout not provided') |
| 233 | + |
| 234 | + offsets = _extract_offsets(layout.fields) |
| 235 | + |
| 236 | + match layout.variants: |
| 237 | + case Single(index): |
| 238 | + return _decode_enum_single( |
| 239 | + data=data, |
| 240 | + discriminants=discriminants, |
| 241 | + fields=fields, |
| 242 | + offsets=offsets, |
| 243 | + # --- |
| 244 | + tag_index=index, |
| 245 | + # --- |
| 246 | + types=types, |
| 247 | + ) |
| 248 | + case Multiple( |
| 249 | + tag=tag, |
| 250 | + tag_encoding=tag_encoding, |
| 251 | + tag_field=tag_field, |
| 252 | + variants=variants, |
| 253 | + ): |
| 254 | + return _decode_enum_multiple( |
| 255 | + data=data, |
| 256 | + discriminants=discriminants, |
| 257 | + fields=fields, |
| 258 | + offsets=offsets, |
| 259 | + # --- |
| 260 | + tag=tag, |
| 261 | + tag_encoding=tag_encoding, |
| 262 | + tag_field=tag_field, |
| 263 | + variant_layouts=variants, |
| 264 | + # --- |
| 265 | + types=types, |
| 266 | + ) |
| 267 | + case _: |
| 268 | + raise AssertionError('Undhandled case') |
| 269 | + |
| 270 | + |
| 271 | +def _extract_offsets(fields_shape: FieldsShape) -> list[MachineSize]: |
| 272 | + match fields_shape: |
| 273 | + case ArbitraryFields(offsets=offsets): |
| 274 | + return offsets |
| 275 | + case _: |
| 276 | + raise ValueError(f'Unsupported fields shape: {fields_shape}') |
| 277 | + |
| 278 | + |
| 279 | +def _decode_enum_single( |
| 280 | + *, |
198 | 281 | data: bytes, |
199 | 282 | discriminants: list[int], |
200 | 283 | fields: list[list[Ty]], |
| 284 | + offsets: list[MachineSize], |
| 285 | + tag_index: int, |
| 286 | + types: Mapping[Ty, TypeMetadata], |
201 | 287 | ) -> Value: |
202 | | - # The only supported case for now is when there are no fields |
203 | | - if any(tys for tys in fields): |
204 | | - raise ValueError('TODO - implement this case') |
| 288 | + assert len(fields) == 1, 'Expected a single list of field types for single-variant enum' |
| 289 | + tys = fields[0] |
| 290 | + |
| 291 | + assert len(discriminants) == 1, 'Expected a single discriminant for single-variant enum' |
| 292 | + discriminant = discriminants[0] |
| 293 | + assert tag_index == discriminant, 'Assumed tag_index to be the same as the discriminant' |
| 294 | + |
| 295 | + field_values = _decode_fields(data=data, tys=tys, offsets=offsets, types=types) |
| 296 | + return AggregateValue(0, field_values) |
| 297 | + |
| 298 | + |
| 299 | +def _decode_enum_multiple( |
| 300 | + *, |
| 301 | + data: bytes, |
| 302 | + discriminants: list[int], |
| 303 | + fields: list[list[Ty]], |
| 304 | + offsets: list[MachineSize], |
| 305 | + # --- |
| 306 | + tag: Scalar, |
| 307 | + tag_encoding: TagEncoding, |
| 308 | + tag_field: int, |
| 309 | + variant_layouts: list[LayoutShape], |
| 310 | + # --- |
| 311 | + types: Mapping[Ty, TypeMetadata], |
| 312 | +) -> Value: |
| 313 | + if not isinstance(tag_encoding, Direct): |
| 314 | + raise ValueError(f'Unsupported encoding: {tag_encoding}') |
| 315 | + |
| 316 | + assert tag_field == 0, 'Assumed tag field to be zero' |
| 317 | + assert len(offsets) == 1, 'Assumed offsets to only contain the tag offset' |
| 318 | + tag_offset = offsets[0] |
| 319 | + tag_value = _extract_tag_value(data=data, tag_offset=tag_offset, tag=tag) |
205 | 320 |
|
206 | | - tag = int.from_bytes(data, byteorder='little', signed=False) |
207 | 321 | try: |
208 | | - variant_idx = discriminants.index(tag) |
| 322 | + variant_idx = discriminants.index(tag_value) |
209 | 323 | except ValueError as err: |
210 | | - raise ValueError(f'Tag not found: {tag}') from err |
| 324 | + raise ValueError(f'Tag not found: {tag_value}') from err |
| 325 | + |
| 326 | + tys = fields[variant_idx] |
| 327 | + |
| 328 | + variant_layout = variant_layouts[variant_idx] |
| 329 | + field_offsets = _extract_offsets(variant_layout.fields) |
| 330 | + assert isinstance(variant_layout.variants, Single) |
211 | 331 |
|
212 | | - return AggregateValue(variant_idx, ()) |
| 332 | + field_values = _decode_fields(data=data, tys=tys, offsets=field_offsets, types=types) |
| 333 | + return AggregateValue(variant_idx, field_values) |
| 334 | + |
| 335 | + |
| 336 | +def _decode_fields( |
| 337 | + *, |
| 338 | + data: bytes, |
| 339 | + tys: list[Ty], |
| 340 | + offsets: list[MachineSize], |
| 341 | + types: Mapping[Ty, TypeMetadata], |
| 342 | +) -> list[Value]: |
| 343 | + res: list[Value] = [] |
| 344 | + for ty, offset in zip(tys, offsets, strict=True): |
| 345 | + type_info = types[ty] |
| 346 | + size_in_bytes = type_info.nbytes(types) |
| 347 | + field_data = data[offset.in_bytes : offset.in_bytes + size_in_bytes] |
| 348 | + value = decode_value(field_data, type_info, types) |
| 349 | + res.append(value) |
| 350 | + return res |
| 351 | + |
| 352 | + |
| 353 | +def _extract_tag_value(*, data: bytes, tag_offset: MachineSize, tag: Scalar) -> int: |
| 354 | + match tag: |
| 355 | + case Initialized( |
| 356 | + value=PrimitiveInt( |
| 357 | + length=length, |
| 358 | + signed=signed, |
| 359 | + ), |
| 360 | + valid_range=_, |
| 361 | + ): |
| 362 | + tag_data = data[tag_offset.in_bytes : tag_offset.in_bytes + length.value] |
| 363 | + return int.from_bytes(tag_data, byteorder='little', signed=signed) |
| 364 | + case _: |
| 365 | + raise ValueError('Unsupported tag: {tag}') |
0 commit comments