Skip to content

Commit 6f147e5

Browse files
committed
exir: add flatbuffer-to-program reader
This continues the work from #17333. Change-Id: I35ac4cd5f6430ea89939453344c13e056b5c746c Signed-off-by: Chizkiyahu Raful <chizkiyahu.raful@arm.com>
1 parent a279b72 commit 6f147e5

4 files changed

Lines changed: 231 additions & 22 deletions

File tree

exir/_serialize/_flatbuffer_program.py

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import enum
99
import functools
1010
import importlib
11+
import pkgutil
1112
import tempfile
1213

1314
from contextvars import ContextVar
1415
from dataclasses import fields, is_dataclass
1516
from functools import lru_cache
16-
from typing import Any, Dict, Optional
17+
from types import ModuleType
18+
from typing import Any, Dict, get_args, get_origin, get_type_hints, Optional, Union
1719

1820
import flatbuffers # pyre-ignore[21]
1921
from executorch.exir._serialize._flatbuffer import (
@@ -22,6 +24,7 @@
2224
_prepare_schema,
2325
_SchemaInfo,
2426
)
27+
from executorch.exir._serialize.generated import executorch_flatbuffer as _generated_fb
2528
from executorch.exir._serialize.generated.executorch_flatbuffer import (
2629
BackendDelegateInlineData as _BackendDelegateInlineData,
2730
Buffer as _Buffer,
@@ -33,6 +36,7 @@
3336

3437
_T_CLASS_CACHE: Dict[type, type] = {}
3538
_FIELD_NAME_CACHE: Dict[type, tuple[tuple[str, str], ...]] = {}
39+
_TYPE_HINTS_CACHE: Dict[type, Dict[str, Any]] = {}
3640
_BUFFER_ALIGNMENT: ContextVar[int] = ContextVar("_BUFFER_ALIGNMENT", default=1)
3741
_DELEGATE_ALIGNMENT: ContextVar[int] = ContextVar("_DELEGATE_ALIGNMENT", default=1)
3842

@@ -64,6 +68,15 @@ def _dataclass_field_map(dataclass_type: type) -> tuple[tuple[str, str], ...]:
6468
return mapping
6569

6670

71+
def _dataclass_type_hints(dataclass_type: type) -> Dict[str, Any]:
    """Return the resolved type hints of ``dataclass_type``, memoized per class.

    The first request for a class resolves its annotations with
    ``typing.get_type_hints`` and stores the resulting mapping in
    ``_TYPE_HINTS_CACHE``; later requests for the same class return that
    stored mapping object.

    Args:
        dataclass_type: The class whose annotations should be resolved.

    Returns:
        A mapping from attribute name to its resolved type annotation.
    """
    hints = _TYPE_HINTS_CACHE.get(dataclass_type)
    if hints is None:
        # Cache miss: resolve once and remember the result for this class.
        hints = get_type_hints(dataclass_type)
        _TYPE_HINTS_CACHE[dataclass_type] = hints
    return hints
78+
79+
6780
def _create_aligned_byte_vector(builder: Any, data: bytes, alignment: int) -> int:
6881
if not _is_valid_alignment(alignment):
6982
raise ValueError(f"Bad alignment {alignment}")
@@ -194,6 +207,126 @@ def convert_program(val: Program) -> ProgramT:
194207
return _convert_dataclass(val)
195208

196209

210+
# The generated FlatBuffer Python modules import child tables/unions as modules
211+
# (for example, Program.ExecutionPlan becomes the ExecutionPlan module), but the
212+
# unpacking helpers later expect those globals to be the corresponding classes.
213+
# Rebind module globals like ExecutionPlan -> ExecutionPlan.ExecutionPlan so the
214+
# generated InitFromObj()/InitFromPackedBuf() code can instantiate nested types.
215+
def _patch_generated_module_aliases(module: ModuleType) -> None:
216+
for name, maybe_module in vars(module).items():
217+
if not isinstance(maybe_module, ModuleType):
218+
continue
219+
maybe_class = getattr(maybe_module, name, None)
220+
if isinstance(maybe_class, type):
221+
setattr(module, name, maybe_class)
222+
223+
224+
@lru_cache(maxsize=1)
def _patch_generated_flatbuffer_aliases() -> None:
    """Patch the class aliases of every module in the generated package.

    Each module found directly under ``_generated_fb.__path__`` is imported
    and handed to ``_patch_generated_module_aliases``. The ``lru_cache``
    decorator caches the (argument-less) call, so after the first successful
    run further calls return without repeating the work.
    """
    package_name = _generated_fb.__name__
    for module_info in pkgutil.iter_modules(_generated_fb.__path__):
        _patch_generated_module_aliases(
            importlib.import_module(f"{package_name}.{module_info.name}")
        )
230+
231+
232+
def _flatbuffer_dataclass_names(val: Any) -> tuple[str, Optional[str]]:
233+
val_type_name = type(val).__name__
234+
if val_type_name.endswith("T"):
235+
return val_type_name, val_type_name[:-1]
236+
return val_type_name, None
237+
238+
239+
def _matches_dataclass_union_type(
240+
union_type: Any, val_type_name: str, val_dataclass_name: Optional[str]
241+
) -> bool:
242+
if not is_dataclass(union_type):
243+
return False
244+
union_name = union_type.__name__
245+
return union_name == val_type_name or (
246+
val_dataclass_name is not None and union_name == val_dataclass_name
247+
)
248+
249+
250+
def _matches_non_dataclass_union_type(union_type: Any, val: Any) -> bool:
251+
if union_type is Any:
252+
return True
253+
if union_type is str and isinstance(val, (bytes, bytearray, memoryview)):
254+
return True
255+
union_origin = get_origin(union_type)
256+
if union_origin is list and hasattr(val, "__iter__"):
257+
return True
258+
return isinstance(union_type, type) and isinstance(val, union_type)
259+
260+
261+
def _union_choice_from_value(union_types: tuple[Any, ...], val: Any) -> Any:
    """Pick the member of a ``Union`` annotation that matches ``val``.

    Args:
        union_types: The members of the union, in declaration order.
        val: The value to classify.

    Returns:
        For ``val is None``: ``type(None)`` when it is one of the members,
        else ``None``. For any other value: the first non-``NoneType`` member
        accepted by ``_matches_dataclass_union_type`` or
        ``_matches_non_dataclass_union_type``, or ``None`` when no member
        matches.
    """
    none_type = type(None)
    if val is None:
        # A missing value can only select the NoneType member, if present.
        return next(
            (member for member in union_types if member is none_type), None
        )

    val_type_name, val_dataclass_name = _flatbuffer_dataclass_names(val)
    for member in union_types:
        if member is none_type:
            continue
        if _matches_dataclass_union_type(
            member, val_type_name, val_dataclass_name
        ) or _matches_non_dataclass_union_type(member, val):
            return member
    return None
278+
279+
280+
def _convert_from_flatbuffer_value(val: Any, expected_type: Any) -> Any:
281+
if val is None:
282+
return None
283+
284+
origin = get_origin(expected_type)
285+
if origin is list:
286+
item_type = get_args(expected_type)[0]
287+
return [_convert_from_flatbuffer_value(item, item_type) for item in val]
288+
289+
if origin is Union:
290+
union_type = _union_choice_from_value(get_args(expected_type), val)
291+
if union_type is None:
292+
raise TypeError(
293+
f"Could not match value type {type(val)} to {expected_type}"
294+
)
295+
if union_type is type(None):
296+
return None
297+
return _convert_from_flatbuffer_value(val, union_type)
298+
299+
if expected_type is bytes:
300+
return _coerce_bytes(val)
301+
if expected_type is str and isinstance(val, (bytes, bytearray, memoryview)):
302+
return _coerce_bytes(val).decode("utf-8")
303+
if is_dataclass(expected_type):
304+
return _convert_from_flatbuffer_dataclass(val, expected_type)
305+
if isinstance(expected_type, type) and issubclass(expected_type, enum.Enum):
306+
if isinstance(val, expected_type):
307+
return val
308+
return expected_type(val)
309+
if isinstance(expected_type, type):
310+
return expected_type(val)
311+
return val
312+
313+
314+
def _convert_from_flatbuffer_dataclass(val: Any, dataclass_type: type) -> Any:
    """Build a ``dataclass_type`` instance from a flatbuffer object.

    For each ``(src_name, dst_name)`` pair yielded by
    ``_dataclass_field_map(dataclass_type)``, the attribute ``dst_name`` is
    read from ``val``, converted to the type hint registered under
    ``src_name``, and passed to the dataclass constructor as the keyword
    argument ``src_name``.

    Args:
        val: The object whose attributes supply the field values.
        dataclass_type: The dataclass to instantiate.

    Returns:
        A new ``dataclass_type`` instance.
    """
    type_hints = _dataclass_type_hints(dataclass_type)
    init_kwargs = {
        src_name: _convert_from_flatbuffer_value(
            getattr(val, dst_name), type_hints[src_name]
        )
        for src_name, dst_name in _dataclass_field_map(dataclass_type)
    }
    return dataclass_type(**init_kwargs)
322+
323+
324+
def _flatbuffer_to_program(program_data: bytes) -> Program:
    """Parse binary flatbuffer data into a ``Program`` dataclass.

    The generated flatbuffer modules are patched first (see
    ``_patch_generated_flatbuffer_aliases``), then the buffer is unpacked with
    ``ProgramT.InitFromPackedBuf`` and the result is converted field by field
    into a ``Program``.

    Args:
        program_data: The serialized flatbuffer program.

    Returns:
        The ``Program`` built from ``program_data``.
    """
    _patch_generated_flatbuffer_aliases()
    return _convert_from_flatbuffer_dataclass(
        ProgramT.InitFromPackedBuf(program_data), Program
    )
328+
329+
197330
@lru_cache(maxsize=1)
198331
def _get_schema_info(
199332
constant_tensor_alignment: Optional[int], delegate_alignment: Optional[int]
@@ -213,11 +346,7 @@ def _program_to_flatbuffer(
213346
constant_tensor_alignment: Optional[int] = None,
214347
delegate_alignment: Optional[int] = None,
215348
) -> _FlatbufferResult:
216-
"""Converts a Program dataclass into binary flatbuffer data.
217-
218-
Unlike _program_json_to_flatbuffer(), this does not use JSON or invoke
219-
flatc to build the binary.
220-
"""
349+
"""Converts a Program dataclass into binary flatbuffer data."""
221350
schema_info = _get_schema_info(constant_tensor_alignment, delegate_alignment)
222351
_set_pack_alignments(schema_info.tensor_alignment, schema_info.delegate_alignment)
223352
_install_fast_packers()

exir/_serialize/_program.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
from executorch.exir._serialize._cord import Cord
1919
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
20-
from executorch.exir._serialize._flatbuffer import (
21-
_FlatbufferResult,
22-
_program_flatbuffer_to_json,
20+
from executorch.exir._serialize._flatbuffer import _FlatbufferResult
21+
from executorch.exir._serialize._flatbuffer_program import (
22+
_flatbuffer_to_program,
23+
_program_to_flatbuffer,
2324
)
24-
from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer
2525
from executorch.exir._serialize._named_data_store import (
2626
NamedDataStore,
2727
NamedDataStoreOutput,
@@ -757,9 +757,7 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile:
757757
segment_base_offset = eh.segment_base_offset
758758

759759
# Parse the flatbuffer data.
760-
program: Program = _json_to_program(
761-
_program_flatbuffer_to_json(program_data[:program_size])
762-
)
760+
program: Program = _flatbuffer_to_program(program_data[:program_size])
763761

764762
if segment_base_offset != 0:
765763
# Move segment data back into the Program.

exir/_serialize/test/test_flatbuffer_program.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
_program_flatbuffer_to_json,
1212
_program_json_to_flatbuffer,
1313
)
14-
from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer
14+
from executorch.exir._serialize._flatbuffer_program import (
15+
_flatbuffer_to_program,
16+
_program_to_flatbuffer,
17+
)
1518
from executorch.exir._serialize._program import _json_to_program, _program_to_json
1619
from executorch.exir.backend.compile_spec_schema import CompileSpec
1720
from executorch.exir.schema import (
@@ -172,6 +175,13 @@ def test_roundtrip_via_json(self) -> None:
172175
program2 = _json_to_program(_program_flatbuffer_to_json(result.data))
173176
self.assertEqual(program2, program)
174177

178+
def test_roundtrip_via_direct_python(self) -> None:
    # Serialize with non-default alignments, read the bytes back with the
    # direct flatbuffer reader, and require the result to equal the input.
    expected = self._make_program()
    serialized = _program_to_flatbuffer(
        expected, constant_tensor_alignment=32, delegate_alignment=64
    )
    roundtripped = _flatbuffer_to_program(serialized.data)
    self.assertEqual(roundtripped, expected)
184+
175185
def test_flatbuffer_paths_match(self) -> None:
176186
program = self._make_program()
177187
cases = [

0 commit comments

Comments
 (0)