# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Takes an ExportedArtifact, or a collection of ExportedArtifacts, in execution dialect, and turns
them into a single ExecuTorch Program.

The provided ExportedArtifact's graph modules are in execution dialect, and the emitter parses and
converts them into ExecuTorch instructions. The emitter walks the provided graphs and, as it
encounters concrete values such as tensors or ints, converts them to the serialized format and
stores them in a list for later use. The emitter walks the graph by traversing fx.Nodes, which
come in a variety of forms and are the primitives of execution at the graph module level. The three
most common kinds we care about are 'call_function', 'placeholder', and 'output'. 'placeholder' and
'output' handle I/O for the module, and 'call_function' handles everything else. Within
'call_function' we may encounter an operator or a delegate call. In that case we parse the schema,
emit all the inputs and outputs (unless they have already been emitted), and then convert the
actual function call into an ExecuTorch instruction such as KernelCall or DelegateCall.

When control flow is present in the graph module, it takes the form of a few different types of
'call_function'. Today (June 14th 2023) only cond and map are supported. The actual operations of
these, such as the true/false branches of cond or the mapping function of map, are stored as
sub-graph modules. When these are encountered during emission, the emitter recursively emits them
and their instructions.
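
As a rough illustration of that walk, here is a minimal, self-contained sketch in plain Python.
It uses neither torch.fx nor the ExecuTorch schema. SketchNode, SketchChain, and walk are
hypothetical stand-ins, and the op strings are only labels. It shows the three node kinds being
dispatched, with each value emitted once and then referenced by its index.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class SketchNode:
    # Hypothetical stand-in for torch.fx.Node with only the fields this sketch needs.
    name: str
    op: str  # "placeholder", "call_function", or "output"
    target: str = "none"
    args: Tuple[str, ...] = ()


@dataclass
class SketchChain:
    # Hypothetical stand-in for the schema's Chain.
    inputs: List[int] = field(default_factory=list)
    outputs: List[int] = field(default_factory=list)
    instructions: List[Tuple[str, List[int]]] = field(default_factory=list)


def walk(nodes: List[SketchNode]) -> SketchChain:
    values: List[str] = []  # stands in for the EValue table
    ids: Dict[str, int] = {}  # node name -> index in `values`
    chain = SketchChain()

    def value_id(name: str) -> int:
        # Emit each value once; later uses reuse the same index.
        if name not in ids:
            ids[name] = len(values)
            values.append(name)
        return ids[name]

    for node in nodes:
        if node.op == "placeholder":
            chain.inputs.append(value_id(node.name))
        elif node.op == "call_function":
            # Emit the inputs, then the output slot, then the instruction itself.
            arg_ids = [value_id(a) for a in node.args]
            out_id = value_id(node.name)
            chain.instructions.append((node.target, arg_ids + [out_id]))
        elif node.op == "output":
            chain.outputs.extend(value_id(a) for a in node.args)
    return chain


graph = [
    SketchNode("x", "placeholder"),
    SketchNode("y", "placeholder"),
    SketchNode("add", "call_function", "aten::add.out", ("x", "y")),
    SketchNode("mul", "call_function", "aten::mul.out", ("add", "y")),
    SketchNode("out", "output", args=("mul",)),
]
chain = walk(graph)
```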
"""

# TODO(jakeszwe): add information here about how weights and other parameters are handled in the
# presence of aot autograd param lifting.

# pyre-strict

import ctypes
import hashlib
import operator
import typing
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union

import executorch.exir.memory as memory
import executorch.extension.pytree as ex_pytree
import torch
import torch.fx
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.operator.convert import is_out_variant
from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node
from executorch.exir.schema import (
    AllocationDetails,
    BackendDelegate,
    BackendDelegateDataReference,
    BackendDelegateInlineData,
    Bool,
    BoolList,
    Buffer,
    Chain,
    ContainerMetadata,
    DataLocation,
    DelegateCall,
    Double,
    DoubleList,
    EValue,
    ExecutionPlan,
    ExtraTensorInfo,
    FreeCall,
    Instruction,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    MoveCall,
    Null,
    Operator,
    OptionalTensorList,
    ScalarType,
    String,
    Tensor,
    TensorDataLocation,
    TensorList,
    TensorShapeDynamism,
)
from executorch.exir.tensor import (
    AddressSpaceOverflowException,
    dim_order_from_stride,
    layout_enum,
    make_allocation_info,
    make_tensor_value,
    memory_format_enum,
    scalar_type_enum,
    TensorSpec,
)
from executorch.exir.types import LeafValueSpec, ValueSpec
from torch._subclasses.fake_tensor import FakeTensor
from torch.export.exported_program import ExportedProgram, ExportGraphSignature
from torch.fx.node import Node
from torch.utils import _pytree as pytree
from typing_extensions import TypeAlias


@dataclass
class _ProgramState:
    """State shared between all methods of a program and the graph module it represents.

    Initialized once within emit_program and then shared across each entry point as they are
    emitted.
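
    A minimal sketch of the content-hash deduplication that constant_buffer and
    cached_spec_hash_values support, in plain Python with no schema types. SketchConstantTable
    is a hypothetical helper, and bytes stand in for Buffer objects. As in the real table, index 0
    holds an empty placeholder. The first real constant therefore lands at index 1, and a repeated
    payload gets its earlier index back.

```python
import hashlib
from typing import Dict, List


class SketchConstantTable:
    # Hypothetical, simplified mirror of constant_buffer and cached_spec_hash_values.
    def __init__(self) -> None:
        # Index 0 is reserved for non-constant tensors, so start with an empty blob.
        self.buffers: List[bytes] = [b""]
        self.hash_to_index: Dict[str, int] = {}

    def add(self, data: bytes) -> int:
        key = hashlib.sha256(data).hexdigest()
        idx = self.hash_to_index.get(key, -1)
        if idx == -1:
            # Unseen payload: append it and remember its index.
            idx = len(self.buffers)
            self.hash_to_index[key] = idx
            self.buffers.append(data)
        return idx


table = SketchConstantTable()
first = table.add(b"weights-a")
second = table.add(b"weights-b")
again = table.add(b"weights-a")
```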
    """

    # Parallel list of specs and the buffers that backed them. Add 1 to any index into this list,
    # because index 0 in the constant_buffer is reserved.
    allocated_specs: List[TensorSpec] = field(default_factory=list)
    # Weights in any arbitrary graph_module only need to compare against weights from previously
    # emitted graph modules, not any weights emitted from itself. This should speed up the lookup
    # from O(N) to O(1).
    cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
    cached_spec_mutable_hash_values: Dict[str, int] = field(default_factory=dict)
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    mutable_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
    # and should be copied to Program.backend_delegate_data.
    backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
    # Delegate cache that is used across all entry points. Key is the hash of the delegated payload.
    backend_delegate_data_cache: Dict[str, int] = field(default_factory=dict)

    # Constants are optionally stored in external files.
    # Aggregate unique external constants into one buffer.
    external_constant_buffer: List[bytes] = field(default_factory=list)
    external_constant_hash: Dict[str, int] = field(default_factory=dict)
    # Each constant_tag groups a set of constants together.
    # {constant_tag: {fqn: index into external_constant_buffer}}
    external_constant_map: Dict[str, Dict[str, int]] = field(default_factory=dict)


@dataclass
class _EmitterState:
    """State of a single emitter.

    Local to at least the entry point, and may be local to a subgraph of an entry point originating
    from control flow.
    """

    values: List[EValue]
    operators: List[Operator]
    delegates: List[BackendDelegate]
    operator_cache: Dict[Tuple[str, str], int]
    emit_stacktrace: bool
    emit_mutable_buffer_names: bool

    spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)

    def spec2id(self, spec: TensorSpec) -> int:
        """Map a TensorSpec to its value index in the values array."""
        assert spec in self.spec2id_dict, f"Spec is not found: {spec.debug()}"
        return self.spec2id_dict[spec]


@dataclass
class _AbstractValue:
    """Represents an already emitted EValue."""

    # Index in the values table of this EValue.
    id: int

    # Used for sanity checks for functions that expect to only receive AbstractValues.
    tensor: Optional[Tensor]


_EmitterValue: TypeAlias = Union[
    _AbstractValue, List[_AbstractValue], Tuple[_AbstractValue, ...]
]

_PythonValue: TypeAlias = Union[bool, int, float, torch.Tensor, List["_PythonValue"]]
_SchemaType: TypeAlias = Union[
    torch.OptionalType,
    torch.ListType,
    torch.FloatType,
    torch.BoolType,
    torch.IntType,
    torch.StringType,
    torch.TensorType,
]

_Target: TypeAlias = Union[Callable[..., _PythonValue], str]

_Argument: TypeAlias = Union[
    _EmitterValue,
    Tuple["_Argument", ...],
    List["_Argument"],
    Dict[str, "_Argument"],
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.Tensor,
    torch.memory_format,
    torch.layout,
    None,
]

_DelegateDebugIdentifierMap: TypeAlias = Union[
    Dict[int, Tuple[int]], Dict[str, Tuple[int]]
]


class _Emitter(torch.fx.Interpreter):
    """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
    given traced torch.fx.GraphModule to the flatbuffer schema."""

    # pyre-ignore[13]: Attribute `node` is never initialized.
    node: torch.fx.Node

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        emitter_state: _EmitterState,
        program_state: _ProgramState,
        instruction_start_offset: int = 0,
        binding_input_values: Optional[List[_AbstractValue]] = None,
        binding_output_values: Optional[List[_AbstractValue]] = None,
    ) -> None:
        super().__init__(graph_module)
        self.emitter_state = emitter_state
        self.program_state = program_state
        self.outputs: List[int] = []

        self.chain = Chain(
            inputs=[],
            outputs=[],
            instructions=[],
            stacktrace=None,
        )

        if "non_const_buffer_sizes" not in graph_module.meta.keys():
            raise RuntimeError(
                "Must set 'non_const_buffer_sizes' in graph meta in memory planning pass"
            )
        self.instruction_start_offset = instruction_start_offset
        self.binding_input_values = binding_input_values
        self.binding_output_values = binding_output_values
        self.graph_module: torch.fx.GraphModule = graph_module
        self.nodes: List[torch.fx.Node] = list(self.graph_module.graph.nodes)

        # Marks each placeholder node with its order so that we can match it with the
        # corresponding AbstractValue coming from the top level.
        self.placeholder_count = 0

        self.concrete_output_ids: List[_AbstractValue] = []
        self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
        self.instruction_id_to_num_outs_map: Dict[int, int] = {}
        self.instr_id_to_delegate_debug_id_map: Dict[
            int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
        ] = {}

    def _emit_node_specific_error(self, node: torch.fx.Node, err_msg: str) -> str:
        """Returns 'err_msg' with node specific information attached."""
        err_msg = f"Failed with error: {str(err_msg)}\n" + inspect_node(
            self.graph_module.graph, node
        )
        return err_msg

    def _internal_assert_emitter(
        self, pred: bool, node: torch.fx.Node, assert_msg: str
    ) -> None:
        """If pred is False, construct and raise a node specific error message."""
        if not pred:
            raise InternalError(self._emit_node_specific_error(node, assert_msg))

    def _emit_int_list(self, val: List[_Argument]) -> EValue:
        """Emits a list of integers as a collection of EValues.

        For every argument in 'val':
            - If it is a concrete value, emit it and then place its location in the boxed list.
            - If it is already an abstract value, just place its location in the boxed list.

        Int lists are boxed to handle SymInts, whose values are determined at runtime but which can
        still end up inside lists for ops like view_copy(Tensor self, SymInt[] size).
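
        A minimal sketch of the boxing described above, assuming plain Python stand-ins.
        SketchAbstractValue and box_int_list are hypothetical, and a Python list plays the role of
        the values table. Already-emitted values are referenced by index, while each concrete int
        is appended as a new value.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class SketchAbstractValue:
    # Hypothetical stand-in for _AbstractValue: the index of an already-emitted value.
    id: int


def box_int_list(
    items: List[Union[int, SketchAbstractValue]], values: List[object]
) -> List[int]:
    boxed: List[int] = []
    for item in items:
        if isinstance(item, SketchAbstractValue):
            # e.g. a SymInt computed at runtime: reference its existing slot.
            boxed.append(item.id)
        elif isinstance(item, int):
            # A concrete int gets its own new slot in the values table.
            values.append(item)
            boxed.append(len(values) - 1)
        else:
            raise TypeError(f"Unsupported type encountered in int list: {item!r}")
    return boxed


values: List[object] = ["x", "s0"]  # slot 1 holds a runtime size
sizes = box_int_list([SketchAbstractValue(1), 4], values)
```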
        """
        boxed_list = []
        for item in val:
            if isinstance(item, _AbstractValue):
                boxed_list.append(item.id)
            elif isinstance(item, int):
                boxed_list.append(
                    self._emit_evalue(self._constant_to_evalue(item, None)).id
                )
            else:
                self._internal_assert_emitter(
                    False, self.node, "Unsupported type encountered in int list."
                )

        return EValue(IntList(boxed_list))

    def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
        """Emits a list type.

        Emits the list stored in val. If the list is of Tensors, Optionals, or Ints, the emitted list
        is boxed; otherwise the values are constant at runtime and stored inline.

        NOTE: When symbool and symfloat are supported, bool and float lists will be stored boxed.
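
        A rough sketch of how each element type maps to a schema list container. The string
        element types and sketch_emit_list are hypothetical simplifications of the JIT types. For
        int and Tensor lists, the items are assumed to already be value-table indices. An
        Optional[Tensor] list marks a missing entry with -1.

```python
from typing import List, Sequence, Tuple


def sketch_emit_list(elem_type: str, items: Sequence[object]) -> Tuple[str, List[object]]:
    # elem_type is a hypothetical string stand-in for the JIT element type.
    if elem_type == "bool":
        return ("BoolList", list(items))  # inline literals
    if elem_type == "float":
        return ("DoubleList", list(items))  # inline literals
    if elem_type == "int":
        return ("IntList", list(items))  # boxed: value-table indices
    if elem_type == "Tensor":
        return ("TensorList", list(items))  # boxed: value-table indices
    if elem_type == "Optional[Tensor]":
        # Boxed, with -1 marking a None entry.
        return ("OptionalTensorList", [-1 if i is None else i for i in items])
    raise ValueError(f"Unknown list type: {elem_type}")


flags = sketch_emit_list("bool", [True, False])
maybe_tensors = sketch_emit_list("Optional[Tensor]", [3, None, 5])
```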
        """

        if isinstance(val_type, torch.BoolType):
            return EValue(BoolList(typing.cast(List[bool], val)))

        if isinstance(val_type, torch.IntType):
            return self._emit_int_list(val)

        if isinstance(val_type, torch.FloatType):
            return EValue(DoubleList(typing.cast(List[float], val)))

        if isinstance(val_type, torch.TensorType):
            values = []
            for v in val:
                assert isinstance(v, _AbstractValue)
                self._internal_assert_emitter(
                    v.tensor is not None,
                    self.node,
                    "AbstractValue corresponding to tensor type doesn't contain tensor value",
                )
                values.append(v.id)
            return EValue(TensorList(values))

        if isinstance(val_type, torch.OptionalType):
            # Refine Optional[T] to its element type T.
            actual_type = val_type.getElementType()
            if isinstance(actual_type, torch.TensorType):
                vals = []
                for v in val:
                    if v is None:
                        vals.append(-1)
                    else:
                        assert isinstance(v, _AbstractValue)
                        vals.append(v.id)
                return EValue(OptionalTensorList(vals))

        raise ExportError(
            ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}"
        )

    def _get_allocation_info(self, spec: TensorSpec) -> AllocationDetails:
        """Returns the allocation info for a given TensorSpec."""
        self._internal_assert_emitter(
            isinstance(spec.mem_id, int) and spec.mem_id >= 0,
            self.node,
            f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}",
        )

        self._internal_assert_emitter(
            isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
            self.node,
            f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}",
        )
        try:
            allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset)
        except AddressSpaceOverflowException as e:
            raise InternalError(
                self._emit_node_specific_error(
                    self.node,
                    (
                        f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, "
                        f"such as ConstraintBasedSymShapeEvalPass, this may be caused by an "
                        f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) "
                        "during torch.export()."
                    ),
                )
            )
        return allocation_info

    def _save_to_external_constant_map(
        self,
        fqn: str,
        buffer_idx: int,
        constant_tag: str,
    ) -> None:
        """
        Saves an external constant to the external constant map.
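
        A minimal sketch of how external constants are shared. It mirrors the three external_*
        fields of _ProgramState in plain Python (SketchExternalConstants is hypothetical).
        Identical bytes are stored once, but every fully qualified name still gets its own entry
        under its constant_tag. This is the fqn0->data, fqn1->data case that
        _tensor_spec_to_evalue handles. Unlike constant_buffer, the external buffer has no
        reserved index 0.

```python
import hashlib
from typing import Dict, List


class SketchExternalConstants:
    # Hypothetical mirror of external_constant_buffer/hash/map in _ProgramState.
    def __init__(self) -> None:
        self.buffer: List[bytes] = []
        self.hash_to_idx: Dict[str, int] = {}
        self.tag_map: Dict[str, Dict[str, int]] = {}

    def add(self, fqn: str, data: bytes, tag: str) -> int:
        key = hashlib.sha256(data).hexdigest()
        idx = self.hash_to_idx.get(key, -1)
        if idx == -1:
            # First time this payload is seen: store the bytes once.
            idx = len(self.buffer)
            self.hash_to_idx[key] = idx
            self.buffer.append(data)
        # Every fqn gets an entry, even when its bytes are shared with another fqn.
        self.tag_map.setdefault(tag, {})[fqn] = idx
        return idx


ext = SketchExternalConstants()
ext.add("linear.weight", b"abc", "model")
ext.add("linear2.weight", b"abc", "model")
third = ext.add("bias", b"xyz", "model")
```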
        """
        # Buffer data should already be in the external_constant_buffer.
        assert buffer_idx < len(self.program_state.external_constant_buffer)
        if constant_tag not in self.program_state.external_constant_map:
            self.program_state.external_constant_map[constant_tag] = {}
        self.program_state.external_constant_map[constant_tag][fqn] = buffer_idx

    def _save_new_const_tensor(
        self,
        spec: TensorSpec,
        buffer_data: bytes,
        hashed: str,
        allocation_info: Optional[AllocationDetails] = None,
        constant_tag: Optional[str] = None,
    ) -> int:
        """Saves a new constant tensor to the external, mutable, or constant buffer and returns its
        buffer index."""

        self.program_state.allocated_specs.append(spec)
        # buffer_idx is the current length of the target list, i.e. the position the new buffer is
        # appended at. constant_buffer and mutable_buffer reserve index 0, so their new entries
        # always start at 1.
        buffer = Buffer(storage=buffer_data)

        # Tensor is stored outside of the PTE file.
        if (
            spec.extra_tensor_info is not None
            and spec.extra_tensor_info.fully_qualified_name is not None
            and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
        ):
            assert (
                constant_tag is not None
            ), "Constant tag is not set for external tensor"
            # TODO (#7633): Handle case where we have both mutable and non mutable weights that we want to put in the same external file.
            # We will need to create 2 segments in that case, but it'll be a bit until we see this case. LLM finetuning will probably require this.

            buffer_idx = len(self.program_state.external_constant_buffer)
            self.program_state.external_constant_hash[hashed] = buffer_idx
            self.program_state.external_constant_buffer.append(buffer_data)
            self._save_to_external_constant_map(
                spec.extra_tensor_info.fully_qualified_name, buffer_idx, constant_tag
            )

        # Tensor is mutable with initial state. Place it into the mutable segment.
        elif allocation_info:
            buffer_idx = len(self.program_state.mutable_buffer)
            self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
            self.program_state.mutable_buffer.append(buffer)

        # Tensor is stored in the PTE file.
        else:
            buffer_idx = len(self.program_state.constant_buffer)
            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
            self.program_state.constant_buffer.append(buffer)

        return buffer_idx

    def _tensor_spec_to_evalue(
        self, spec: TensorSpec, constant_tag: Optional[str] = None
    ) -> EValue:
        """Constructs an EValue from the given TensorSpec."""

        allocation_info = None
        buffer_idx = 0

        # Need to memory plan.
        # Some users set mem_id on all tensors and then rely on the
        # default algos to set offsets, so we need to check both.
        if spec.mem_id is not None and spec.mem_offset is not None:
            # Tensor is an activation.
            allocation_info = self._get_allocation_info(spec)

        # Tensor is either a constant tensor, or a mutable tensor with an initial state.
        if spec.const:
            # Tensor with a blob we need to serialize. May not actually be constant at runtime
            # if it's a weight with an associated gradient.
            spec_array_type = (
                ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
            )

            buffer_data = (
                bytes(
                    ctypes.cast(
                        typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
                        ctypes.POINTER(spec_array_type),
                    ).contents
                )
                if spec.allocated_memory != 0
                else b""
            )

            hashed = hashlib.sha256(buffer_data).hexdigest()

            if allocation_info and spec.extra_tensor_info is None:
                buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                    hashed, -1
                )
            elif (
                spec.extra_tensor_info is not None
                and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
            ):
                buffer_idx = self.program_state.external_constant_hash.get(hashed, -1)
                if buffer_idx != -1:
                    # This constant already exists in the external_constant_buffer
                    # and doesn't need to be duplicated. However, the fqn is unique
                    # and should be added, i.e. we have the case fqn0->data, fqn1->data.
                    # When buffer_idx == -1, the data is new and is added with
                    # `_save_new_const_tensor` below.
                    assert spec.extra_tensor_info.fully_qualified_name is not None
                    assert constant_tag is not None
                    self._save_to_external_constant_map(
                        spec.extra_tensor_info.fully_qualified_name,
                        buffer_idx,
                        constant_tag,
                    )
            else:
                buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)

            # Haven't seen this constant before.
            if buffer_idx == -1:
                buffer_idx = self._save_new_const_tensor(
                    spec, buffer_data, hashed, allocation_info, constant_tag
                )

        if spec.const and spec.nbytes() != len(buffer_data):
            raise InternalError(
                self._emit_node_specific_error(
                    self.node,
                    f"Tensor spec has buffer of size {len(buffer_data)}, but expected nbytes of {spec.nbytes()}",
                )
            )

        # For constant tensors, allocation_info = None.
        return EValue(make_tensor_value(buffer_idx, allocation_info, spec))

    def _get_list_tuple_jit_type(
        self, val: Union[Tuple[_Argument], List[_Argument]]
    ) -> _SchemaType:
        """Returns the JIT type for the given Python list or tuple."""
        assert isinstance(
            val, (list, tuple)
        ), f"Input to _get_list_tuple_jit_type was expected to be an instance of a list or tuple but received {type(val)}"
        is_tensor_type = all(
            isinstance(v, _AbstractValue) and v.tensor is not None for v in val
        )
        if is_tensor_type:
            return torch.TensorType.get()
        # bool is a subclass of int in Python, so it must be checked before int;
        # otherwise a list of bools would be typed as an int list.
        elif isinstance(val[0], bool):
            return torch.BoolType.get()
        elif isinstance(val[0], int):
            return torch.IntType.get()
        elif isinstance(val[0], float):
            return torch.FloatType.get()

        raise InternalError(
            self._emit_node_specific_error(
                self.node,
                "Couldn't determine JitType for list/tuple of elements. Only supports int, float, bool, and Tensor.",
            )
        )

    def _constant_to_evalue(  # noqa: C901
        self,
        val: _Argument,
        val_type: Optional[_SchemaType],
    ) -> EValue:
        """Converts a constant value to an EValue.

        Returns an EValue given the Python representation and JIT type. On common paths there should
        always be a JIT type provided. Users can pass in None to infer the JIT type, but this
        should never be the default case due to the existence of container types.
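
        A small sketch of the scalar branch order in this method, using hypothetical string tags
        instead of EValue types (sketch_constant_kind is not part of this module). It illustrates
        why bool must be tested before int: isinstance(True, int) is True in Python. The dtype,
        layout, memory_format, and list handling are left out.

```python
from typing import Tuple


def sketch_constant_kind(val: object) -> Tuple[str, object]:
    # Mirrors the check order used for scalar constants.
    if val is None:
        return ("Null", None)
    if isinstance(val, float):
        return ("Double", val)
    # bool must be tested before int: isinstance(True, int) is True.
    if isinstance(val, bool):
        return ("Bool", val)
    if isinstance(val, int):
        return ("Int", val)
    if isinstance(val, str):
        return ("String", val)
    raise TypeError(f"Unsupported constant type: {type(val).__name__}")


flag = sketch_constant_kind(True)
count = sketch_constant_kind(3)
```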
        """
        if val is None:
            return EValue(Null())

        if isinstance(val, (list, tuple)):
            # Refine Optional[List[T]] -> List[T]. This works because if the val was None, it would
            # have been converted to Null before this point.
            if val_type is None:
                val_type = torch.ListType(
                    self._get_list_tuple_jit_type(val)  # pyre-ignore
                )
            if isinstance(val_type, torch.OptionalType):
                val_type = val_type.getElementType()
            assert isinstance(val_type, torch.ListType)
            return self._emit_list(
                typing.cast(List[_Argument], val),
                typing.cast(_SchemaType, val_type.getElementType()),
            )

        if isinstance(val, float):
            return EValue(Double(val))

        if isinstance(val, bool):
            return EValue(Bool(val))

        if isinstance(val, int):
            return EValue(Int(val))

        if isinstance(val, str):
            return EValue(String(val))

        if isinstance(val, torch.dtype):
            return EValue(Int(scalar_type_enum(val)))

        if isinstance(val, torch.layout):
            return EValue(Int(layout_enum(val)))

        if isinstance(val, torch.memory_format):
            try:
                return EValue(Int(memory_format_enum(val)))
            except KeyError:
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        f"Tensor has a memory_format that is unsupported in ExecuTorch: {val}",
                    )
                )

        if isinstance(val, torch.Tensor):
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                self._emit_node_specific_error(
                    self.node,
                    "constant_to_evalue should not be encountering constant tensors, they should be emitted through other codepaths.",
                ),
            )

        raise ExportError(
            ExportErrorType.NOT_SUPPORTED,
            self._emit_node_specific_error(
                self.node, f"Unsupported constant type: {type(val).__name__}"
            ),
        )

    def _emit_evalue(self, val: EValue) -> _AbstractValue:
        """Writes an EValue to the emitter state.

        Given an EValue, adds it to the emitter_state's values table and returns the AbstractValue
        representing it.
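
        A minimal sketch of this values-table bookkeeping, assuming plain Python stand-ins
        (SketchValueTable and SketchAbstractValue are hypothetical). Each emitted value is
        appended, its index becomes the abstract value's id, and the tensor field is set only for
        tensors.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchAbstractValue:
    id: int  # index into the values table
    tensor: Optional[object]  # set only when the emitted value is a tensor


class SketchValueTable:
    # Hypothetical stand-in for emitter_state.values.
    def __init__(self) -> None:
        self.values: List[object] = []

    def emit(self, val: object, is_tensor: bool = False) -> SketchAbstractValue:
        self.values.append(val)
        return SketchAbstractValue(len(self.values) - 1, val if is_tensor else None)


table = SketchValueTable()
scalar = table.emit(7)
tensor = table.emit("tensor-0", is_tensor=True)
```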
        """
        self.emitter_state.values.append(val)
        tensor = val.val if isinstance(val.val, Tensor) else None
        return _AbstractValue(len(self.emitter_state.values) - 1, tensor)

    def _emit_spec(self, spec: ValueSpec) -> _EmitterValue:
        """Given the provided spec, constructs the corresponding EValue from it and then emits it."""

        def _process(spec: LeafValueSpec) -> _AbstractValue:
            if isinstance(spec, (list, tuple)):
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        "Node spec should be either non-nested container or a scalar object",
                    )
                )

            # ScalarSpec can theoretically be handled fine, but it shouldn't be appearing, so rather
            # than handle it, assert that it isn't supposed to be present. In the future, if it has
            # a reason to appear, we can relax this assert.
            self._internal_assert_emitter(
                isinstance(spec, TensorSpec),
                self.node,
                f"Invalid node spec: expected TensorSpec, received {spec}",
            )

            ret = self._emit_evalue(self._tensor_spec_to_evalue(spec))  # pyre-ignore
            self.emitter_state.spec2id_dict[spec] = ret.id  # pyre-ignore
            return ret

        return pytree.tree_map(_process, spec)

    def _merge_chain(self, chain: Chain) -> None:
        """Merges the chain generated from subgraphs (like those originating from control flow) back
        into the main program chain."""
        self.chain.instructions.extend(chain.instructions)

    def _emit_cond(
        self,
        args: Tuple[_Argument, ...],
        subemitter_binding_output_values: Optional[List[_AbstractValue]],
    ) -> List[_AbstractValue]:
        """Emits control_flow.cond.

        Converts the higher order op into jumps and inlines the submodules of the true and false
        branches. Control flow can be nested. The general emitted structure is:

            <Jump Instruction> - decides which branch to take
            <True Branch>
            <Jump Instruction> - jumps to End Of Cond after executing true branch
            <False Branch>
            <End Of Cond>
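
        A sketch that lays out this structure and executes it, assuming the cond starts at
        instruction 0 with nothing emitted before it. The real emitter adds
        instruction_start_offset and the current chain length. SketchInstr, lay_out_cond, and run
        are hypothetical. For a true branch of t instructions, the first jump targets t + 2, the
        first false-branch instruction. For a false branch of f instructions, the second jump
        (taken on a constant False) targets t + f + 2, just past the false branch.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class SketchInstr:
    # Hypothetical instruction: either an opaque "op" or a "jump_false".
    kind: str
    name: str = "jump"  # label used by "op" instructions
    cond: Optional[str] = None  # value name tested by "jump_false"
    dest: int = -1  # absolute destination of "jump_false"


def lay_out_cond(true_ops: List[str], false_ops: List[str]) -> List[SketchInstr]:
    t, f = len(true_ops), len(false_ops)
    # Skip this jump, the true branch, and the trailing jump: first false-branch op.
    prog = [SketchInstr("jump_false", cond="pred", dest=t + 2)]
    prog += [SketchInstr("op", name=n) for n in true_ops]
    # Jump on a constant False: always taken, lands just past the false branch.
    prog.append(SketchInstr("jump_false", cond="const_false", dest=t + f + 2))
    prog += [SketchInstr("op", name=n) for n in false_ops]
    return prog


def run(prog: List[SketchInstr], env: Dict[str, bool]) -> List[str]:
    executed: List[str] = []
    pc = 0
    while pc < len(prog):
        ins = prog[pc]
        if ins.kind == "jump_false" and ins.cond is not None and not env[ins.cond]:
            pc = ins.dest
            continue
        if ins.kind == "op":
            executed.append(ins.name)
        pc += 1
    return executed


prog = lay_out_cond(["a", "b"], ["c"])
taken_true = run(prog, {"pred": True, "const_false": False})
taken_false = run(prog, {"pred": False, "const_false": False})
```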
        """
        pred, true_branch, false_branch, inputs = args

        # Emit the true branch.
        assert isinstance(true_branch, torch.fx.GraphModule)
        true_branch_emitter = _Emitter(
            true_branch,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions)
            + 1,
            binding_input_values=typing.cast(List[_AbstractValue], inputs),
            binding_output_values=subemitter_binding_output_values,
        )
        true_branch_emitter.run()

        # Emit the jump.
        assert isinstance(pred, _AbstractValue)
        jf_instruction_to_skip_true = Instruction(
            JumpFalseCall(
                cond_value_index=pred.id,
                destination_instruction=self.instruction_start_offset
                + len(self.chain.instructions)
                + len(true_branch_emitter.chain.instructions)
                # This jump instruction should point at the instruction after all instructions
                # for the true branch. We add 2 to account for this instruction we are creating
                # right now and for the jump instruction emitted after the true branch.
                + 2,
            )
        )

        # Insert the branch-picking jump instruction into the main chain.
        self.chain.instructions.append(jf_instruction_to_skip_true)
        # Now that we created the true branch instructions, we move them to the main chain.
        self._merge_chain(true_branch_emitter.chain)

        # Emit the false branch.
        assert isinstance(false_branch, torch.fx.GraphModule)
        false_branch_emitter = _Emitter(
            false_branch,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions)
            + 1,
            binding_input_values=typing.cast(List[_AbstractValue], inputs),
            binding_output_values=subemitter_binding_output_values,
        )
        false_branch_emitter.run()

        # We bake in a constant False because this makes the instruction jump over all false
        # branch instructions and land on the first instruction right after the control flow.
        value = self._emit_evalue(EValue(Bool(False)))
        jf_instruction_to_skip_false = Instruction(
            JumpFalseCall(
                cond_value_index=value.id,
                destination_instruction=self.instruction_start_offset
                + len(self.chain.instructions)
                + len(false_branch_emitter.chain.instructions)
                + 1,
            )
        )
        self.chain.instructions.append(jf_instruction_to_skip_false)
        self._merge_chain(false_branch_emitter.chain)
        return subemitter_binding_output_values

    def _emit_map(
        self,
        args: Tuple[_Argument, ...],
        subemitter_binding_output_values: List[_AbstractValue],
    ) -> List[_AbstractValue]:
        """Emits torch.map.

        Converts the higher order op into a loop constructed from jump instructions and primitive
        int operations. A concat-like custom op is also injected into the submodule code to handle
        the construction of the map output.

        Below is what the input graph that is provided to emit_map looks like:

            class TestMapCond(torch.nn.Module):
                def __init__(self):
                    super().__init__()

                def forward(self, x, y):
                    return control_flow.map(map_fn, x, y)

        Corresponding graph:

            def forward(self, arg0_1, arg1_1):
                submodule_0 = self.submodule_0
                map_1 = torch.ops.higher_order.map_impl(submodule_0, arg0_1, arg1_1);  submodule_0 = arg0_1 = arg1_1 = None
                return [map_1]

        submodule_0:

            def forward(self, arg0_1, arg1_1):
                add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return add_tensor

        After the transformations done by emit_map, this is what the submodule program looks like:

            def forward(self, arg0_1, arg1_1):
                sym_size = torch.ops.aten.sym_size(arg0_1)
                # Emitter creates a variable here to track the iteration index.
                select_copy_tensor = torch.ops.aten.select(arg0_1, 0, iteration_index)
                add_tensor = torch.ops.aten.add.Tensor(select_copy_tensor, arg1_1);  arg0_1 = arg1_1 = None
                output_of_map = torch.ops.executorch.prim.et_copy_index(output_of_map, add_tensor, iteration_index)
                iteration_index = torch.ops.executorch.prim.add.int(iteration_index, 1, iteration_index)
                done_bool = torch.ops.executorch.prim.eq.int(iteration_index, sym_size, done_bool)
                # Emitter inserts an instruction here: if done_bool == False, jump back to the
                # select_copy op; otherwise, continue.
                return add_tensor
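
        A list-based sketch of the loop the emitted instructions implement. No tensors or real ops
        are involved: run_map_loop is hypothetical, and the comments name the op each line
        models. The exit check comes after the body, as in the instruction sequence above, so the
        sketch assumes a non-empty leading dimension.

```python
from typing import Callable, List, Sequence


def run_map_loop(xs: Sequence[int], body: Callable[[int], int]) -> List[int]:
    # Hypothetical list-based model of the emitted loop; no tensors involved.
    size = len(xs)  # aten::sym_size(x, 0)
    out = [0] * size  # accumulator filled in by et_copy_index
    iter_idx = 0  # the Int(0) EValue used as the loop counter
    while True:  # loop head: select_copy is the jump-back target
        sliced = xs[iter_idx]  # aten::select_copy(x, 0, iter_idx)
        result = body(sliced)  # inlined map submodule
        out[iter_idx] = result  # et_copy_index(out, result, iter_idx)
        iter_idx = iter_idx + 1  # prim add.int
        done = iter_idx == size  # prim eq.int
        if done:
            break
        # Otherwise JumpFalseCall(done) sends control back to the loop head.
    return out


mapped = run_map_loop([1, 2, 3], lambda v: v * 10)
```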
"""
assert isinstance(
subemitter_binding_output_values, (list, tuple)
), f"Expect a list for subemitter_binding_output_values for map. Got {subemitter_binding_output_values}."
if len(subemitter_binding_output_values) != 1:
raise RuntimeError(
f"Multiple outputs are not supported. Got {len(subemitter_binding_output_values)}."
)
f, mapped_args, inputs = args
assert isinstance(mapped_args, (list, tuple))
num_mapped_args: int = len(mapped_args)
if num_mapped_args != 1:
raise RuntimeError(
f"Emitting map with more than one mapped args is not supported. Got {num_mapped_args}."
)
x = mapped_args[0]
assert isinstance(f, torch.fx.GraphModule)
# Generate the EValue that we will use as our iterator index to keep track of which
# iteration we are currently on.
iter_idx = self._emit_evalue(EValue(Int(0)))
# Generate the kernel call that will output the number of iterations we need to run for this
# input tensor.
op_index, op = self._get_operator(
name="aten::sym_size",
overload="int",
)
sym_size = self._emit_evalue(EValue(Int(0)))
kernel = Instruction(
KernelCall(
op_index=op_index,
args=[x.id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
)
)
self.chain.instructions.append(kernel)
# This kernel call will slice the input tensor along the index specified in iter_idx to
# generate the input slice on which this iteration will be working on.
op_index, op = self._get_operator(
name="aten::select_copy",
overload="int_out",
)
# This select copy has to output to the tensor which is the input placeholder to the map
# sub-graph. That placeholder isn't allocated an EValue id until the map emitter is run, so
# we temporarily store -1 until the map emitter is run during which the placeholder will be
# allocated an EValue id. After the map emitter is run we will retrieve that id and replace
# the -1's.
kernel = Instruction(
KernelCall(
op_index=op_index,
args=[
x.id,
self._emit_evalue(EValue(Int(0))).id,
iter_idx.id,
-1, # out: input placeholder id, patched after the map emitter runs
-1, # returned value: same placeholder id, patched likewise
],
)
)
# Record the absolute index of this instruction: at the end of each iteration the loop
# jumps back here.
jump_to_instruction = self.instruction_start_offset + len(
self.chain.instructions
)
self.chain.instructions.append(kernel)
# Emit the map operator submodule.
map_emitter = _Emitter(
f,
self.emitter_state,
self.program_state,
instruction_start_offset=self.instruction_start_offset
+ len(self.chain.instructions),
# Only the first input (the per-iteration slice) is a placeholder; the rest of the inputs are passed to the map fn unchanged.
binding_input_values=[-1, *inputs],
binding_output_values=subemitter_binding_output_values,
)
map_emitter.run()
# Merge all the instructions from the map submodule.
self._merge_chain(map_emitter.chain)
# Get rid of the return instruction emitted by the map subemitter.
self.chain.instructions.pop()
# At the end of each submodule emit, a move call is normally inserted that moves the
# submodule's output to a deterministic EValue. This is especially useful for if/else
# branches, where the output of either branch must end up in the same EValue. No move is
# needed here because the custom op executorch_prim::et_copy_index, inserted below, does
# that for us.
# Now that the map emitter has finished running, retrieve the input placeholder's EValue
# id and patch the select_copy kernel call to output to that id.
kernel.instr_args.args[-1] = map_emitter.binding_input_values[0].id
kernel.instr_args.args[-2] = kernel.instr_args.args[-1]
self._internal_assert_emitter(
len(map_emitter.concrete_output_ids) == 1,
self.node,
"Map should return only one element",
)
# Call the custom op added specifically for the map operator. It copies this iteration's
# output into the accumulator tensor at position iter_idx; this accumulator tensor is the
# actual output of the map submodule.
op_index, op = self._get_operator(
name="executorch_prim::et_copy_index",
overload="tensor",
)
kernel = Instruction(
KernelCall(
op_index,
args=[
subemitter_binding_output_values[0].id,
map_emitter.concrete_output_ids[0].id,
iter_idx.id,
],
)
)
self.chain.instructions.append(kernel)
# Increment iter_idx to mark that we have completed an iteration.
op_index, op = self._get_operator(
name="executorch_prim::add",
overload="Scalar",
)
kernel = Instruction(
KernelCall(
op_index,
args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
)
)
self.chain.instructions.append(kernel)
jump_bool_value = self._emit_evalue(EValue(Bool(False)))
# Generate the kernel call that checks whether all iterations have completed. If not, the
# jump that follows returns to the select_copy instruction emitted at the start of this
# section.
op_index, op = self._get_operator(
name="executorch_prim::eq",
overload="Scalar",
)
kernel = Instruction(
KernelCall(
op_index,
args=[iter_idx.id, sym_size.id, jump_bool_value.id],
)
)
self.chain.instructions.append(kernel)
jf_beginning_loop = Instruction(
JumpFalseCall(
cond_value_index=jump_bool_value.id,
destination_instruction=jump_to_instruction,
)
)
self.chain.instructions.append(jf_beginning_loop)
# Reset iter_idx to 0 (iter_idx - sym_size) so the loop starts fresh if the program is run again.
op_index, op = self._get_operator(
name="executorch_prim::sub",
overload="Scalar",
)
kernel = Instruction(
KernelCall(
op_index,
args=[iter_idx.id, sym_size.id, iter_idx.id],
)
)
self.chain.instructions.append(kernel)
return subemitter_binding_output_values
def _emit_scan(
self,
args: Tuple[_Argument, ...],
subemitter_binding_output_values: List[_AbstractValue],
) -> List[_AbstractValue]:
"""Emits torch.scan.
Converts the higher order scan op into a loop constructed from jump instructions
and primitive operations. Scan differs from map in that it maintains a carry state
that evolves across iterations.
Scan signature: scan(combine_fn, init, xs, additional_inputs)
- combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs)
and returns (next_carry, y_slice)
- init: Initial carry state (list of tensors)
- xs: Input tensors to scan over (list of tensors, scanned along dim 0)
- additional_inputs: Additional arguments passed to combine_fn
Output: (final_carry, stacked_ys)
- final_carry: The carry state after the last iteration
- stacked_ys: All y outputs stacked along dim 0
Memory Layout:
- carry_outputs (subemitter_binding_output_values[:num_carry]):
Working carry buffers, initialized from init, updated each iteration
- y_outputs (subemitter_binding_output_values[num_carry:]):
Pre-allocated stacked output buffers, filled via et_copy_index
The combine_fn writes to its own temporary output buffers (concrete_output_ids).
After each iteration:
1. Copy combine_fn's carry output -> carry_outputs (for next iteration)
2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)
This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.
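A minimal pure-Python sketch of these semantics follows. It assumes plain lists of floats in place of tensors and omits additional_inputs. scan_reference and CombineFn are hypothetical names used only for illustration. Steps 1 and 2 above correspond to the two commented writes in the loop body.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for combine_fn: (carry, x_slice) -> (next_carry, y_slice).
CombineFn = Callable[[List[float], List[float]], Tuple[List[float], List[float]]]


def scan_reference(
    combine_fn: CombineFn, init: List[float], xs: Sequence[Sequence[float]]
) -> Tuple[List[float], List[List[float]]]:
    num_iters = len(xs[0])  # every entry of xs is scanned along dim 0
    carry_outputs = list(init)  # working carry buffers, initialized from init
    y_outputs: List[List[float]] = []  # one stacked buffer per y output
    for iter_idx in range(num_iters):
        x_slice = [x[iter_idx] for x in xs]
        # combine_fn is assumed to return fresh lists (its own temporaries, like
        # concrete_output_ids), so its outputs never alias carry_outputs.
        next_carry, y_slice = combine_fn(carry_outputs, x_slice)
        # Step 1: copy the new carry into the working carry buffers.
        carry_outputs[:] = next_carry
        # Step 2: like et_copy_index, write each y at row iter_idx of its stacked buffer.
        if not y_outputs:
            y_outputs = [[] for _ in y_slice]
        for stacked, y in zip(y_outputs, y_slice):
            stacked.append(y)
    return carry_outputs, y_outputs
```

For example, a running sum uses the carry as the accumulator and also emits it as y. scan_reference then returns the final total together with the stacked partial sums.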
"""
combine_fn, init, xs, additional_inputs = args
assert isinstance(subemitter_binding_output_values, (list, tuple)), (
f"Expected list for subemitter_binding_output_values. "
f"Got {type(subemitter_binding_output_values).__name__}: "
f"{subemitter_binding_output_values}."
)
assert isinstance(combine_fn, torch.fx.GraphModule)
assert isinstance(init, (list, tuple))
assert isinstance(xs, (list, tuple))
assert isinstance(additional_inputs, (list, tuple))
num_carry = len(init)
num_xs = len(xs)
carry_outputs = list(subemitter_binding_output_values[:num_carry])
y_outputs = list(subemitter_binding_output_values[num_carry:])
if num_xs < 1: