88import enum
99import functools
1010import importlib
11+ import pkgutil
1112import tempfile
1213
1314from contextvars import ContextVar
1415from dataclasses import fields , is_dataclass
1516from functools import lru_cache
16- from typing import Any , Dict , Optional
17+ from types import ModuleType
18+ from typing import Any , Dict , get_args , get_origin , get_type_hints , Optional , Union
1719
1820import flatbuffers # pyre-ignore[21]
1921from executorch .exir ._serialize ._flatbuffer import (
2224 _prepare_schema ,
2325 _SchemaInfo ,
2426)
27+ from executorch .exir ._serialize .generated import executorch_flatbuffer as _generated_fb
2528from executorch .exir ._serialize .generated .executorch_flatbuffer import (
2629 BackendDelegateInlineData as _BackendDelegateInlineData ,
2730 Buffer as _Buffer ,
3336
3437_T_CLASS_CACHE : Dict [type , type ] = {}
3538_FIELD_NAME_CACHE : Dict [type , tuple [tuple [str , str ], ...]] = {}
39+ _TYPE_HINTS_CACHE : Dict [type , Dict [str , Any ]] = {}
3640_BUFFER_ALIGNMENT : ContextVar [int ] = ContextVar ("_BUFFER_ALIGNMENT" , default = 1 )
3741_DELEGATE_ALIGNMENT : ContextVar [int ] = ContextVar ("_DELEGATE_ALIGNMENT" , default = 1 )
3842
@@ -64,6 +68,15 @@ def _dataclass_field_map(dataclass_type: type) -> tuple[tuple[str, str], ...]:
6468 return mapping
6569
6670
71+ def _dataclass_type_hints (dataclass_type : type ) -> Dict [str , Any ]:
72+ cached = _TYPE_HINTS_CACHE .get (dataclass_type )
73+ if cached is not None :
74+ return cached
75+ type_hints = get_type_hints (dataclass_type )
76+ _TYPE_HINTS_CACHE [dataclass_type ] = type_hints
77+ return type_hints
78+
79+
6780def _create_aligned_byte_vector (builder : Any , data : bytes , alignment : int ) -> int :
6881 if not _is_valid_alignment (alignment ):
6982 raise ValueError (f"Bad alignment { alignment } " )
@@ -194,6 +207,126 @@ def convert_program(val: Program) -> ProgramT:
194207 return _convert_dataclass (val )
195208
196209
210+ # The generated FlatBuffer Python modules import child tables/unions as modules
211+ # (for example, Program.ExecutionPlan becomes the ExecutionPlan module), but the
212+ # unpacking helpers later expect those globals to be the corresponding classes.
213+ # Rebind module globals like ExecutionPlan -> ExecutionPlan.ExecutionPlan so the
214+ # generated InitFromObj()/InitFromPackedBuf() code can instantiate nested types.
215+ def _patch_generated_module_aliases (module : ModuleType ) -> None :
216+ for name , maybe_module in vars (module ).items ():
217+ if not isinstance (maybe_module , ModuleType ):
218+ continue
219+ maybe_class = getattr (maybe_module , name , None )
220+ if isinstance (maybe_class , type ):
221+ setattr (module , name , maybe_class )
222+
223+
224+ @lru_cache (maxsize = 1 )
225+ def _patch_generated_flatbuffer_aliases () -> None :
226+ package_name = _generated_fb .__name__
227+ for module_info in pkgutil .iter_modules (_generated_fb .__path__ ):
228+ module = importlib .import_module (f"{ package_name } .{ module_info .name } " )
229+ _patch_generated_module_aliases (module )
230+
231+
232+ def _flatbuffer_dataclass_names (val : Any ) -> tuple [str , Optional [str ]]:
233+ val_type_name = type (val ).__name__
234+ if val_type_name .endswith ("T" ):
235+ return val_type_name , val_type_name [:- 1 ]
236+ return val_type_name , None
237+
238+
239+ def _matches_dataclass_union_type (
240+ union_type : Any , val_type_name : str , val_dataclass_name : Optional [str ]
241+ ) -> bool :
242+ if not is_dataclass (union_type ):
243+ return False
244+ union_name = union_type .__name__
245+ return union_name == val_type_name or (
246+ val_dataclass_name is not None and union_name == val_dataclass_name
247+ )
248+
249+
250+ def _matches_non_dataclass_union_type (union_type : Any , val : Any ) -> bool :
251+ if union_type is Any :
252+ return True
253+ if union_type is str and isinstance (val , (bytes , bytearray , memoryview )):
254+ return True
255+ union_origin = get_origin (union_type )
256+ if union_origin is list and hasattr (val , "__iter__" ):
257+ return True
258+ return isinstance (union_type , type ) and isinstance (val , union_type )
259+
260+
261+ def _union_choice_from_value (union_types : tuple [Any , ...], val : Any ) -> Any :
262+ if val is None :
263+ for union_type in union_types :
264+ if union_type is type (None ):
265+ return union_type
266+ return None
267+
268+ val_type_name , val_dataclass_name = _flatbuffer_dataclass_names (val )
269+
270+ for union_type in union_types :
271+ if union_type is type (None ):
272+ continue
273+ if _matches_dataclass_union_type (union_type , val_type_name , val_dataclass_name ):
274+ return union_type
275+ if _matches_non_dataclass_union_type (union_type , val ):
276+ return union_type
277+ return None
278+
279+
280+ def _convert_from_flatbuffer_value (val : Any , expected_type : Any ) -> Any :
281+ if val is None :
282+ return None
283+
284+ origin = get_origin (expected_type )
285+ if origin is list :
286+ item_type = get_args (expected_type )[0 ]
287+ return [_convert_from_flatbuffer_value (item , item_type ) for item in val ]
288+
289+ if origin is Union :
290+ union_type = _union_choice_from_value (get_args (expected_type ), val )
291+ if union_type is None :
292+ raise TypeError (
293+ f"Could not match value type { type (val )} to { expected_type } "
294+ )
295+ if union_type is type (None ):
296+ return None
297+ return _convert_from_flatbuffer_value (val , union_type )
298+
299+ if expected_type is bytes :
300+ return _coerce_bytes (val )
301+ if expected_type is str and isinstance (val , (bytes , bytearray , memoryview )):
302+ return _coerce_bytes (val ).decode ("utf-8" )
303+ if is_dataclass (expected_type ):
304+ return _convert_from_flatbuffer_dataclass (val , expected_type )
305+ if isinstance (expected_type , type ) and issubclass (expected_type , enum .Enum ):
306+ if isinstance (val , expected_type ):
307+ return val
308+ return expected_type (val )
309+ if isinstance (expected_type , type ):
310+ return expected_type (val )
311+ return val
312+
313+
314+ def _convert_from_flatbuffer_dataclass (val : Any , dataclass_type : type ) -> Any :
315+ result = {}
316+ type_hints = _dataclass_type_hints (dataclass_type )
317+ for src_name , dst_name in _dataclass_field_map (dataclass_type ):
318+ result [src_name ] = _convert_from_flatbuffer_value (
319+ getattr (val , dst_name ), type_hints [src_name ]
320+ )
321+ return dataclass_type (** result )
322+
323+
324+ def _flatbuffer_to_program (program_data : bytes ) -> Program :
325+ _patch_generated_flatbuffer_aliases ()
326+ program_t = ProgramT .InitFromPackedBuf (program_data )
327+ return _convert_from_flatbuffer_dataclass (program_t , Program )
328+
329+
197330@lru_cache (maxsize = 1 )
198331def _get_schema_info (
199332 constant_tensor_alignment : Optional [int ], delegate_alignment : Optional [int ]
@@ -213,11 +346,7 @@ def _program_to_flatbuffer(
213346 constant_tensor_alignment : Optional [int ] = None ,
214347 delegate_alignment : Optional [int ] = None ,
215348) -> _FlatbufferResult :
216- """Converts a Program dataclass into binary flatbuffer data.
217-
218- Unlike _program_json_to_flatbuffer(), this does not use JSON or invoke
219- flatc to build the binary.
220- """
349+ """Converts a Program dataclass into binary flatbuffer data."""
221350 schema_info = _get_schema_info (constant_tensor_alignment , delegate_alignment )
222351 _set_pack_alignments (schema_info .tensor_alignment , schema_info .delegate_alignment )
223352 _install_fast_packers ()
0 commit comments