Skip to content

Commit 1bf01ef

Browse files
authored
Implement linking for Memory allocations (#673)
Closes #631
1 parent 8cef90e commit 1bf01ef

2 files changed

Lines changed: 138 additions & 37 deletions

File tree

kmir/src/kmir/linker.py

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ def link(smirs: list[SMIRInfo]) -> SMIRInfo:
1919

2020
_LOGGER.info(f'Maximum type ID (offset) is {offset}, linking {len(smirs)} smir.json files')
2121

22-
for smir, offset in zip(smirs, [offset * i for i in range(len(smirs))], strict=True):
23-
_LOGGER.debug(f'Offset {offset} for smir {smir._smir["name"]}')
24-
apply_offset(smir, offset)
22+
for i, smir in enumerate(smirs):
23+
smir_offset = offset * i
24+
_LOGGER.debug(f'Offset {smir_offset} for smir {smir._smir["name"]}')
25+
apply_offset(smir, smir_offset)
2526

2627
result_dict = {
27-
'name': ','.join([smir._smir['name'] for smir in smirs]),
28+
'name': ','.join(smir._smir['name'] for smir in smirs),
2829
'crate_id': 0, # HACK
2930
'allocs': [a for smir in smirs for a in smir._smir['allocs']],
3031
'functions': [f for smir in smirs for f in smir._smir['functions']],
@@ -38,10 +39,7 @@ def link(smirs: list[SMIRInfo]) -> SMIRInfo:
3839

3940

4041
def id_range(smir: SMIRInfo) -> int:
41-
f_max = max([0] + list(smir.function_symbols.keys()))
42-
ty_max = max([0] + list(smir.types.keys()))
43-
span_range = max([0] + list(smir.spans.keys()))
44-
return max(f_max, ty_max, span_range)
42+
return max(0, *smir.function_symbols, *smir.types, *smir.spans, *smir.allocs)
4543

4644

4745
def apply_offset(info: SMIRInfo, offset: int) -> None:
@@ -58,8 +56,15 @@ def apply_offset(info: SMIRInfo, offset: int) -> None:
5856
]
5957
info._smir['spans'] = [(i + offset, span) for i, span in info._smir['spans']]
6058

61-
# TODO adjust all alloc IDs (incl. alloc provenance)
62-
# TODO then adjust alloc references during item traversal
59+
for alloc in dic['allocs']: # alloc: AllocInfo
60+
alloc['alloc_id'] += offset
61+
alloc['ty'] += offset
62+
global_alloc = alloc['global_alloc'] # global_alloc: GlobalAlloc
63+
match global_alloc:
64+
case {'Memory': allocation}: # global_alloc: Memory, allocation: Allocation
65+
apply_offset_provenance(allocation['provenance'], offset)
66+
case _:
67+
raise ValueError('Unsupported or invalid GlobalAlloc data: {global_alloc}')
6368

6469
# traverse item bodies and replace all `ty` fields
6570
for item in info._smir['items']:
@@ -71,21 +76,21 @@ def apply_offset_typeInfo(typeinfo: dict, offset: int) -> dict:
7176
# returns the updated (i.e., mutated) `typeinfo`` dictionary
7277
# 'PrimitiveType' in typeinfo:
7378
if 'EnumType' in typeinfo:
74-
typeinfo['EnumType']['adt_def'] = typeinfo['EnumType']['adt_def'] + offset
79+
typeinfo['EnumType']['adt_def'] += offset
7580
typeinfo['EnumType']['fields'] = [[x + offset for x in l] for l in typeinfo['EnumType']['fields']]
7681
elif 'StructType' in typeinfo:
7782
typeinfo['StructType']['fields'] = [x + offset for x in typeinfo['StructType']['fields']]
78-
typeinfo['StructType']['adt_def'] = typeinfo['StructType']['adt_def'] + offset
83+
typeinfo['StructType']['adt_def'] += offset
7984
elif 'UnionType' in typeinfo:
80-
typeinfo['UnionType']['adt_def'] = typeinfo['UnionType']['adt_def'] + offset
85+
typeinfo['UnionType']['adt_def'] += offset
8186
elif 'ArrayType' in typeinfo:
82-
typeinfo['ArrayType']['elem_type'] = typeinfo['ArrayType']['elem_type'] + offset
87+
typeinfo['ArrayType']['elem_type'] += offset
8388
if 'size' in typeinfo['ArrayType'] and typeinfo['ArrayType']['size'] is not None:
8489
apply_offset_tyconst(typeinfo['ArrayType']['size']['kind'], offset)
8590
elif 'PtrType' in typeinfo:
86-
typeinfo['PtrType']['pointee_type'] = typeinfo['PtrType']['pointee_type'] + offset
91+
typeinfo['PtrType']['pointee_type'] += offset
8792
elif 'RefType' in typeinfo:
88-
typeinfo['RefType']['pointee_type'] = typeinfo['RefType']['pointee_type'] + offset
93+
typeinfo['RefType']['pointee_type'] += offset
8994
elif 'TupleType' in typeinfo:
9095
typeinfo['TupleType']['types'] = [x + offset for x in typeinfo['TupleType']['types']]
9196
# 'FunType' in typeinfo:
@@ -100,14 +105,14 @@ def apply_offset_item(item: dict, offset: int) -> None:
100105
if 'MonoItemFn' in item and 'body' in item['MonoItemFn']:
101106
body = item['MonoItemFn']['body']
102107
for local in body['locals']:
103-
local['ty'] = local['ty'] + offset
104-
local['span'] = local['span'] + offset
108+
local['ty'] += offset
109+
local['span'] += offset
105110
for block in body['blocks']:
106111
for stmt in block['statements']:
107112
apply_offset_stmt(stmt['kind'], offset)
108-
stmt['span'] = stmt['span'] + offset
113+
stmt['span'] += offset
109114
apply_offset_terminator(block['terminator']['kind'], offset)
110-
block['terminator']['span'] = block['terminator']['span'] + offset
115+
block['terminator']['span'] += offset
111116
# adjust span in var_debug_info, each item's source_info.span
112117
for thing in body['var_debug_info']:
113118
thing['source_info']['span'] += offset
@@ -146,10 +151,13 @@ def apply_offset_operand(op: dict, offset: int) -> None:
146151
elif 'Move' in op:
147152
apply_offset_place(op['Move'], offset)
148153
elif 'Constant' in op:
149-
op['Constant']['const_']['ty'] = op['Constant']['const_']['ty'] + offset
150-
if 'Ty' in op['Constant']['const_']['kind']:
151-
apply_offset_tyconst(op['Constant']['const_']['kind']['Ty']['kind'], offset)
152-
op['Constant']['span'] = op['Constant']['span'] + offset
154+
op['Constant']['const_']['ty'] += offset
155+
match op['Constant']['const_']['kind']:
156+
case {'Ty': val}:
157+
apply_offset_tyconst(val['kind'], offset)
158+
case {'Allocated': val}:
159+
apply_offset_provenance(val['provenance'], offset)
160+
op['Constant']['span'] += offset
153161

154162

155163
def apply_offset_tyconst(tyconst: dict, offset: int) -> None:
@@ -159,9 +167,14 @@ def apply_offset_tyconst(tyconst: dict, offset: int) -> None:
159167
for arg in tyconst['Unevaluated'][1]:
160168
apply_offset_gen_arg(arg, offset)
161169
elif 'Value' in tyconst:
162-
tyconst['Value'][0] = tyconst['Value'][0] + offset
170+
tyconst['Value'][0] += offset
163171
elif 'ZSTValue' in tyconst:
164-
tyconst['ZSTValue'] = tyconst['ZSTValue'] + offset
172+
tyconst['ZSTValue'] += offset
173+
174+
175+
def apply_offset_provenance(provenance: dict, offset: int) -> None:
176+
for i in range(len(provenance['ptrs'])):
177+
provenance['ptrs'][i][1] += offset
165178

166179

167180
def apply_offset_place(place: dict, offset: int) -> None:
@@ -173,15 +186,15 @@ def apply_offset_place(place: dict, offset: int) -> None:
173186
def apply_offset_proj(proj: dict, offset: int) -> None:
174187
# Deref
175188
if 'Field' in proj:
176-
proj['Field'][1] = proj['Field'][1] + offset
189+
proj['Field'][1] += offset
177190
# Index
178191
# ConstantIndex
179192
# Subslice
180193
# Downcast
181194
elif 'OpaqueCast' in proj:
182-
proj['OpaqueCast'] = proj['OpaqueCast'] + offset
195+
proj['OpaqueCast'] += offset
183196
elif 'Subtype' in proj:
184-
proj['Subtype'] = proj['Subtype'] + offset
197+
proj['Subtype'] += offset
185198

186199

187200
def apply_offset_stmt(stmt: dict, offset: int) -> None:
@@ -216,10 +229,10 @@ def apply_offset_rvalue(rval: dict, offset: int) -> None:
216229
elif 'Aggregate' in rval:
217230
# handle AggregateKind
218231
if 'Array' in rval['Aggregate'][0]:
219-
rval['Aggregate'][0]['Array'] = rval['Aggregate'][0]['Array'] + offset # ty field
232+
rval['Aggregate'][0]['Array'] += offset # ty field
220233
# Tuple
221234
elif 'Adt' in rval['Aggregate'][0]:
222-
rval['Aggregate'][0]['Adt'][0] = rval['Aggregate'][0]['Adt'][0] + offset # AdtDef field
235+
rval['Aggregate'][0]['Adt'][0] += offset # AdtDef field
223236
# GenericArgs can recursively contain TyConst, or Ty
224237
for arg in rval['Aggregate'][0]['Adt'][2]:
225238
apply_offset_gen_arg(arg, offset)
@@ -231,15 +244,15 @@ def apply_offset_rvalue(rval: dict, offset: int) -> None:
231244
for arg in rval['Aggregate'][0]['Coroutine'][1]:
232245
apply_offset_gen_arg(arg, offset)
233246
elif 'RawPtr' in rval['Aggregate'][0]:
234-
rval['Aggregate'][0]['RawPtr'][0] = rval['Aggregate'][0]['RawPtr'][0] + offset # ty field
247+
rval['Aggregate'][0]['RawPtr'][0] += offset # ty field
235248
for op in rval['Aggregate'][1]:
236249
apply_offset_operand(op, offset)
237250
elif 'BinaryOp' in rval:
238251
apply_offset_operand(rval['BinaryOp'][1], offset)
239252
apply_offset_operand(rval['BinaryOp'][2], offset)
240253
elif 'Cast' in rval:
241254
apply_offset_operand(rval['Cast'][1], offset)
242-
rval['Cast'][2] = rval['Cast'][2] + offset
255+
rval['Cast'][2] += offset
243256
elif 'CheckedBinaryOp' in rval:
244257
apply_offset_operand(rval['CheckedBinaryOp'][1], offset)
245258
apply_offset_operand(rval['CheckedBinaryOp'][2], offset)
@@ -256,10 +269,10 @@ def apply_offset_rvalue(rval: dict, offset: int) -> None:
256269
apply_offset_tyconst(rval['Repeat'][1]['kind'], offset)
257270
elif 'ShallowInitBox' in rval:
258271
apply_offset_operand(rval['ShallowInitBox'][0], offset)
259-
rval['ShallowInitBox'][1] = rval['ShallowInitBox'][1] + offset
272+
rval['ShallowInitBox'][1] += offset
260273
# ThreadLocalRef
261274
elif 'NullaryOp' in rval:
262-
rval['NullaryOp'][1] = rval['NullaryOp'][1] + offset
275+
rval['NullaryOp'][1] += offset
263276
elif 'UnaryOp' in rval:
264277
apply_offset_operand(rval['UnaryOp'][1], offset)
265278
elif 'Use' in rval:
@@ -269,6 +282,6 @@ def apply_offset_rvalue(rval: dict, offset: int) -> None:
269282
def apply_offset_gen_arg(arg: dict, offset: int) -> None:
270283
# GenericArg may contain a Ty or a TyConst
271284
if 'Type' in arg:
272-
arg['Type'] = arg['Type'] + offset
285+
arg['Type'] += offset
273286
elif 'Const' in arg:
274287
apply_offset_tyconst(arg['Const']['kind'], offset)

kmir/src/kmir/smir.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22

33
import json
44
import logging
5+
from abc import ABC
56
from collections import deque
67
from dataclasses import dataclass
78
from enum import Enum
89
from functools import cached_property, reduce
9-
from typing import TYPE_CHECKING, Final, NewType
10+
from typing import TYPE_CHECKING, NamedTuple, NewType
1011

1112
if TYPE_CHECKING:
1213
from pathlib import Path
14+
from typing import Any, Final
1315

16+
17+
AllocId = NewType('AllocId', int)
1418
Ty = NewType('Ty', int)
1519
AdtDef = NewType('AdtDef', int)
1620

@@ -31,6 +35,12 @@ def from_file(smir_json_file: Path) -> SMIRInfo:
3135
def dump(self, smir_json_file: Path) -> None:
3236
smir_json_file.write_text(json.dumps(self._smir))
3337

38+
@cached_property
39+
def allocs(self) -> dict[AllocId, AllocInfo]:
40+
return {
41+
alloc_info.alloc_id: alloc_info for alloc_info in (AllocInfo.from_dict(dct) for dct in self._smir['allocs'])
42+
}
43+
3444
@cached_property
3545
def types(self) -> dict[Ty, TypeMetadata]:
3646
return {Ty(id): metadata_from_json(type) for id, type in self._smir['types']}
@@ -369,3 +379,81 @@ def compute_closure(start: Ty, edges: dict[Ty, set[Ty]]) -> set[Ty]:
369379
if next in edges:
370380
work.extend(edges[next])
371381
return reached
382+
383+
384+
@dataclass
385+
class AllocInfo:
386+
alloc_id: AllocId
387+
ty: Ty
388+
global_alloc: GlobalAlloc
389+
390+
@staticmethod
391+
def from_dict(dct: dict[str, Any]) -> AllocInfo:
392+
return AllocInfo(
393+
alloc_id=AllocId(dct['alloc_id']),
394+
ty=Ty(dct['ty']),
395+
global_alloc=GlobalAlloc.from_dict(dct['global_alloc']),
396+
)
397+
398+
399+
class GlobalAlloc(ABC): # noqa: B024
400+
@staticmethod
401+
def from_dict(dct: dict[str, Any]) -> GlobalAlloc:
402+
match dct:
403+
case {'Memory': _}:
404+
return Memory.from_dict(dct)
405+
case _:
406+
raise ValueError('Unsupported or invalid GlobalAlloc data: {dct}')
407+
408+
409+
@dataclass
410+
class Memory(GlobalAlloc):
411+
allocation: Allocation
412+
413+
@staticmethod
414+
def from_dict(dct: dict[str, Any]) -> Memory:
415+
return Memory(
416+
allocation=Allocation.from_dict(dct['Memory']),
417+
)
418+
419+
420+
@dataclass
421+
class Allocation:
422+
data: bytes # field 'bytes'
423+
provenance: ProvenanceMap
424+
align: int
425+
mutable: bool # field 'mutability'
426+
427+
@staticmethod
428+
def from_dict(dct: dict[str, Any]) -> Allocation:
429+
return Allocation(
430+
data=bytes(dct['bytes']),
431+
provenance=ProvenanceMap.from_dict(dct['provenance']),
432+
align=int(dct['align']),
433+
mutable={
434+
'Not': False,
435+
'Mut': True,
436+
}[dct['mutability']],
437+
)
438+
439+
440+
@dataclass
441+
class ProvenanceMap:
442+
ptrs: list[ProvenanceItem]
443+
444+
@staticmethod
445+
def from_dict(dct: dict[str, Any]) -> ProvenanceMap:
446+
return ProvenanceMap(
447+
ptrs=[
448+
ProvenanceItem(
449+
size=int(size),
450+
prov=AllocId(prov),
451+
)
452+
for size, prov in dct['ptrs']
453+
],
454+
)
455+
456+
457+
class ProvenanceItem(NamedTuple):
458+
size: int
459+
prov: AllocId

0 commit comments

Comments
 (0)