forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtracer.py
More file actions
803 lines (680 loc) · 27.3 KB
/
tracer.py
File metadata and controls
803 lines (680 loc) · 27.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import copy
import json
import linecache
import traceback
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import executorch.extension.pytree as ex_pytree
import torch
import torch._dynamo as torchdynamo
import torch.fx as fx
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from executorch.exir.common import (
    extract_out_arguments,
    format_schema_name,
    no_dispatch,
    setting_python_recursive_limit,
)
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.graph_module import LeafValue
from executorch.exir.operator.convert import is_out_variant
from executorch.exir.operator.util import _QUANT_PRIMITIVES
from executorch.exir.types import ValueSpec
from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass  # @manual
from torch._decomp import get_decompositions
from torch._dynamo.guards import Guard
from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
from torch.export import default_decompositions
from torch.func import functionalize
from torch.fx.operator_schemas import normalize_function
from torch.utils._pytree import TreeSpec
from typing_extensions import TypeAlias
# Recursive type for anything that flows through a traced program: a leaf
# value (executorch.exir.graph_module.LeafValue) or an arbitrarily nested
# tuple / list / str-keyed dict of such values.
Value: TypeAlias = Union[
    LeafValue,
    Tuple["Value", ...],
    List["Value"],
    Dict[str, "Value"],
]
# Read by PythonTensor.__torch_dispatch__: when truthy, graph nodes are created
# even if no DispatchTracer is active.
# NOTE(review): nothing visible in this module assigns this flag after this
# line; `using_dynamo` toggles the separate TORCHDYNAMO_ENABLED flag instead.
# Confirm whether the two are meant to be the same switch.
torchdynamo_enabled = False
def _source_context(filename: str, lineno: Optional[int], radius: int = 5) -> str:
    """
    Return numbered source lines around ``lineno`` in ``filename``.

    The result is up to ``radius`` lines ending at ``lineno`` (inclusive), a
    line of carets, then up to ``radius`` following lines. Each source line is
    prefixed with its 1-based line number. If ``lineno`` is unknown or the
    source cannot be loaded, ``"<unknown file: unknown line>"`` is returned.

    Source is fetched through ``linecache``, so each file is read at most once
    rather than once per frame per traced op.
    """
    if lineno is None:
        return "<unknown file: unknown line>"
    try:
        lines = linecache.getlines(filename)
    except (OSError, UnicodeDecodeError, SyntaxError):
        # Older linecache versions only swallow OSError; decoding problems
        # must not abort tracing either.
        lines = []
    if not lines:
        return "<unknown file: unknown line>"

    def numbered(start: int, stop: int) -> str:
        # Only the requested window is numbered, not the whole file.
        return "".join(
            str(index + 1) + lines[index]
            for index in range(max(start, 0), min(stop, len(lines)))
        )

    return (
        numbered(lineno - radius, lineno)
        + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
        + numbered(lineno, lineno + radius)
    )


def get_stacktrace() -> List[Dict[str, str]]:
    """
    Get the current stacktrace (between trace() and __torch_dispatch__()).

    For each frame, include the filename, function name, line number, source
    line, and a source-code context window around that line.

    Returns:
        One dict per frame, with all values as strings, under the keys
        ``filename``, ``lineno``, ``name``, ``line`` and ``context``.
        ``context`` is produced by ``_source_context``: the 5 lines up to the
        frame's line, a caret marker line, and the 5 lines after it, or
        ``"<unknown file: unknown line>"`` when the source is unavailable.
    """
    stacktrace = traceback.extract_stack()
    # The stacktrace typically looks like this:
    #
    # 1. I stack frames from the top level runner (e.g., the
    #    test suite runner)
    # 2. J frames in executorch/exir/tracer.py setting up tracing
    #    (call this INIT_EXIR)
    # 3. K frames in user model code (this is what we want to save!)
    # 4. 1 frame in executorch/exir/tracer.py __torch_function__
    #    returning to tracer (call this TRACE_EXIR)
    # 5. H frames in executorch/exir/tracer.py AND torch/_tensor.py
    #    doing all of the internal tracer handling
    #
    # The PyE tests assert that executorch/exir/tracer.py never shows
    # up in the user provided stack traces, so we must oblige them.
    #
    # Assumptions:
    # - Reentrant tracing is not a thing. Thus, the first time
    #   executorch/exir/tracer.py shows up in the trace, we know
    #   THAT is the point at which we start tracing. (An alternative
    #   is that the tracer entry point could record the stack trace
    #   at this time, but I didn't do this.)
    #
    # Our plan is to run a small state machine over these stack frames.
    tracer_path = "executorch/exir/tracer.py"

    def in_tracer(frame: traceback.FrameSummary) -> bool:
        # Normalise Windows separators so the substring test works everywhere.
        return tracer_path in frame.filename.replace("\\", "/")

    # Remove parts before the trace function and parts after entering
    # __torch_dispatch__. Defaults to returning the entire stack trace.
    init_exir_end = 0
    trace_exir_start: Optional[int] = None
    # States refer to the frame segments described above; the kept range is
    # the closed-open interval [init_exir_end, trace_exir_start).
    FIND_INIT_EXIR_START, FIND_INIT_EXIR_END, FIND_TRACE_EXIR_START = range(3)
    state = FIND_INIT_EXIR_START
    for i, frame in enumerate(stacktrace):
        if state == FIND_INIT_EXIR_START:
            if in_tracer(frame):
                state = FIND_INIT_EXIR_END
        elif state == FIND_INIT_EXIR_END:
            if not in_tracer(frame):
                init_exir_end = i
                state = FIND_TRACE_EXIR_START
        elif state == FIND_TRACE_EXIR_START:
            if in_tracer(frame):
                trace_exir_start = i
                break
    stacktrace = stacktrace[init_exir_end:trace_exir_start]

    # torch.fx stack preservation logic expects strings to be passed around.
    # A dictionary of strings is easy to convert to a string and back.
    return [
        {
            "filename": str(frame.filename),
            "lineno": str(frame.lineno),
            "name": str(frame.name),
            "line": str(frame.line),
            "context": _source_context(frame.filename, frame.lineno),
        }
        for frame in stacktrace
    ]
def unwrap_functional(t: torch.Tensor) -> torch.Tensor:
    """
    Return ``t`` with any functionalization wrapper removed.

    Delegates to torch's ``_maybe_unwrap_functional_tensor`` with
    ``reapply_views=False``. ``t`` must be a tensor; this is checked with an
    ``assert``.
    """
    assert isinstance(t, torch.Tensor)
    unwrapped = _maybe_unwrap_functional_tensor(t, reapply_views=False)
    return unwrapped
def unwrap_proxy(t: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
    """
    Map a leaf value to what the FX graph should see in its place.

    Non-tensors come back unchanged. A tensor is first passed through
    ``unwrap_functional``. If the result is a ``PythonTensor``, its recorded
    ``proxy`` is returned; otherwise the unwrapped tensor itself is returned.
    """
    if isinstance(t, torch.Tensor):
        inner = unwrap_functional(t)
        if isinstance(inner, PythonTensor):
            return inner.proxy
        return inner
    return t
def single_return(
    output: LeafValue,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
) -> LeafValue:
    """
    Wrap a single op output with its proxy.

    Tensors are passed to ``wrapper(output, proxy)`` and the wrapper's result
    is returned. Any other value is returned as-is.
    """
    return wrapper(output, proxy) if isinstance(output, torch.Tensor) else output
def tree_return(
    outputs: Value,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
    meta_type: Callable[..., Iterable[ValueSpec]] = tuple,
) -> Value:
    """
    Apply ``single_return`` to every leaf of ``outputs``.

    Leaves are visited in pytree-flatten order. The i-th leaf is paired with
    ``proxy[i]``; indexing happens for every leaf, tensor or not. The results
    are placed back into the original container structure.

    ``meta_type`` is accepted for interface compatibility and is not used.
    """
    leaves, spec = pytree.tree_flatten(outputs)
    wrapped = [
        single_return(leaf, proxy[index], wrapper)
        for index, leaf in enumerate(leaves)
    ]
    return pytree.tree_unflatten(wrapped, spec)
class DummyProxy:
    """
    Placeholder used instead of a real ``torch.fx.Proxy`` when no tracer is
    active, so ``PythonTensor.__torch_dispatch__`` can still run.

    It provides only what that path touches: a ``node`` carrying an empty
    ``meta`` dict, and item access that returns a new ``DummyProxy``.
    """

    def __init__(self) -> None:
        class DummyNode:
            # Minimal node stand-in: only a ``meta`` dict.
            def __init__(self):
                self.meta = {}

        node = DummyNode()
        self.node = node

    def __getitem__(self, key: str) -> "DummyProxy":
        # The key is ignored; every lookup yields an independent placeholder.
        return DummyProxy()
class PythonTensor(torch.Tensor):
    """
    A wrapper tensor subclass used in the DispatchTracer to keep track of
    proxies to construct the FX graph.

    Wrapping something in PythonTensor implicitly detaches gradients. If
    something required grad, we will collect it as if it were a leaf. A
    consequence of detaching in this way is you need to maintain a parameter
    cache when translating tensors into PythonTensor, so you don't create
    multiple copies of a gradient (they are aliased, but they would count as
    independent leaves). An alternate strategy would be to avoid implicitly
    detaching and instead "catch" gradients as they exit the PythonTensor
    boundary.

    Attributes:
        proxy: the FX proxy recorded for this tensor's current value. It is
            re-pointed when an in-place op writes to the tensor.
        is_immutable: when True, an out-variant op that receives this tensor
            as an ``out`` argument raises RuntimeError. DispatchTracer._trace
            sets it for tensor inputs (placeholders).
    """

    __slots__ = ["proxy", "is_immutable"]

    @staticmethod
    def __new__(
        cls, elem: torch.Tensor, proxy: torch.fx.Proxy, is_immutable: bool = False
    ) -> torch.Tensor:
        # Re-type `elem` as a PythonTensor via _make_subclass, carrying over
        # elem.requires_grad, then attach the tracing metadata.
        # assert not elem.requires_grad or not torch.is_grad_enabled()
        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        assert isinstance(r, PythonTensor)
        r.is_immutable: bool = is_immutable
        r.update_proxy(proxy)
        return r

    def update_proxy(self, proxy: torch.fx.Proxy) -> None:
        """Point this tensor at ``proxy`` as its graph value."""
        self.proxy = proxy

    def __repr__(self, *, tensor_contents: None = None) -> str:
        # Formatting happens inside no_dispatch(), presumably so the ops used
        # for printing are not routed through __torch_dispatch__.
        with no_dispatch():
            return f"PythonTensor({self.as_subclass(torch.Tensor)})"

    @classmethod
    def __torch_function__(
        cls,
        # pyre-ignore: Missing parameter annotation [2]
        func,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        """
        Torch-function hook.

        In inference mode, ``layer_norm`` and ``linear`` are sent straight to
        ``__torch_dispatch__`` as the single aten ops ``aten.layer_norm`` /
        ``aten.linear``. Every other call runs normally with torch-function
        subclass handling disabled.
        """
        if kwargs is None:
            kwargs = {}
        if torch.is_inference_mode_enabled():
            if func is torch.nn.functional.layer_norm:
                args, kwargs = normalize_function(func, args, kwargs)  # pyre-fixme[23]
                input, normalized_shape = args
                normalized_shape = list(normalized_shape)
                return cls.__torch_dispatch__(
                    torch.ops.aten.layer_norm.default,
                    types,
                    (input, normalized_shape),
                    kwargs,
                )
            elif func is torch.nn.functional.linear:
                return cls.__torch_dispatch__(
                    torch.ops.aten.linear.default, types, args, kwargs
                )
        with DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(  # noqa: C901
        cls,
        func_overload: torch._ops.OpOverload,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        """
        This function is invoked every time an aten operation is called.

        The op is run twice. The first run is on proxies, to record a graph
        node; it uses a DummyProxy when no tracer is active and
        ``torchdynamo_enabled`` is False. The second run is on the real
        tensors under ``no_dispatch()``, to get concrete outputs. Tensor
        outputs are then re-wrapped as PythonTensors bound to the new proxy.

        Args:
            func_overload: The function that was called that invoked this
                torch_dispatch call
            types:
            args: Arguments that were passed into the function. Each argument
                has type PythonTensor.
            kwargs: Keyword arguments that were passed into the function. Each
                argument has type PythonTensor.

        Raises:
            RuntimeError: if an out-variant op would write into an immutable
                PythonTensor.
        """
        func = func_overload.overloadpacket

        kwargs = kwargs or {}
        # Refuse to let an out-variant op write into an immutable tensor.
        if is_out_variant(func._qualified_op_name, func_overload._overloadname):
            out_args = extract_out_arguments(func_overload._schema, kwargs)
            out_args_iter = [out_args] if not isinstance(out_args, list) else out_args
            for out_arg_name, out_arg_val in out_args_iter:
                if isinstance(out_arg_val, PythonTensor) and out_arg_val.is_immutable:
                    raise RuntimeError(
                        "Immutable tensor `{}` is potentially getting modified by {}".format(
                            out_arg_name, format_schema_name(func_overload._schema)
                        )
                    )

        # Replace every tensor argument with its proxy (see unwrap_proxy).
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_args = ex_pytree.tree_map(unwrap_proxy, args)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_kwargs = ex_pytree.tree_map(unwrap_proxy, kwargs)

        # Get the output of the function
        # Torch-function handling is enabled while the op runs on proxies; the
        # guard object is released by deleting it in `finally`.
        g = _EnableTorchFunction()
        try:
            proxy_out = (
                func_overload(*proxy_args, **proxy_kwargs)
                if DispatchTracer.get() or torchdynamo_enabled
                # Disable node creation when no tracer is active.
                else DummyProxy()
            )
        finally:
            del g

        # Concrete result, computed on the real tensors with dispatch disabled.
        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        # (name ends with "_" but does not start with "_"): args[0] now holds
        # the op's result, so re-point it at the new proxy.
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            if isinstance(args[0], PythonTensor):
                args[0].proxy = proxy_out

        # Record the user stack on the new node unless node meta is being
        # preserved from an earlier trace.
        if not torch.fx.traceback.has_preserved_node_meta():
            proxy_out.node.meta["stack_trace"] = json.dumps(get_stacktrace())

        # Wrap the output tensors with the PythonTensor subclass to propagate to
        # future tracing
        def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue:
            # Some ops (like native_batch_norm_backward) return undefined tensors that get
            # converted into None in python.
            # As the function signature expects tensors, if we directly return these None
            # tensors back to C++, we'll error.
            if e is None:
                e = torch.empty(())
            if isinstance(e, torch.Tensor):
                return PythonTensor(e, proxy)

            # Inplace and out-variant ops may return one of their arguments, which is already
            # a PythonTensor. In this case, we need to update the PythonTensor's associated
            # proxy to the newly created proxy.
            # NOTE(review): PythonTensor subclasses torch.Tensor, so the check
            # above already returns for it and this branch is unreachable as
            # written.
            if isinstance(e, PythonTensor):
                e.update_proxy(proxy)
                return e

            return e

        # A single result is wrapped directly. For a list/tuple result, each
        # leaf is wrapped with the matching proxy[i].
        retval = None
        if not isinstance(real_out, (list, tuple)):
            retval = single_return(real_out, proxy_out, wrap_with_proxy)
        else:
            retval = tree_return(real_out, proxy_out, wrap_with_proxy, type(real_out))
        return retval
@contextmanager
def using_tracer(tracer: Optional["DispatchTracer"]) -> Generator[None, None, None]:
    """
    Make ``tracer`` the module-level ``TRACER`` for the body of the ``with``.

    The previous tracer is restored on exit, even if the body raises.
    ``DispatchTracer.get()`` returns this value.

    A global is needed because some things we want to capture with
    torch_dispatch (for example cond() and shape()) never trap into the
    dispatcher. User-space code therefore needs its own handle on the active
    tracer to trigger graph capturing.
    """
    global TRACER
    saved = TRACER
    TRACER = tracer
    try:
        yield
    finally:
        TRACER = saved
class DispatchTracer(fx.Tracer):
    """
    FX tracer whose graph is built by ``PythonTensor.__torch_dispatch__``.

    ``trace`` installs this tracer as the current one (``using_tracer``),
    wraps tensor inputs as immutable ``PythonTensor`` placeholders and runs
    the callable. The dispatch hook records each op into ``self.graph``.
    Constant tensors, foreign parameters and higher-order-op sub-graphs that
    show up as arguments are attached to ``self.root`` and referenced through
    ``get_attr`` nodes.
    """

    def __init__(self) -> None:
        super().__init__()
        self.root: torch.nn.Module = torch.nn.Module()
        # Constant tensor -> name of the buffer registered for it on self.root.
        self.tensor_attrs: Dict[torch.Tensor, str] = {}
        # Sub-graph (higher-order op operand) -> its submodule name on self.root.
        self.submodules: Dict[fx.GraphModule, str] = {}

    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Value],
        args: Tuple[Value, ...],
        kwargs: Dict[str, Value],
    ) -> Value:
        # Never emit call_module nodes: always run the module's forward inline.
        return forward(*args, **kwargs)

    def _module_getattr(
        self, attr: str, attr_val: Value, parameter_proxy_cache: Dict[str, torch.Tensor]
    ) -> Value:
        """
        Wrap a parameter of ``self.root`` as a ``PythonTensor`` backed by a
        ``get_attr`` proxy. One wrapper per parameter name is cached in
        ``parameter_proxy_cache``. Any other value is returned unchanged.
        """
        if isinstance(attr_val, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        proxy = self.create_proxy("get_attr", n, (), {})
                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
                    return parameter_proxy_cache[n]
            return attr_val
        return attr_val

    def _fresh_attr_name(self, prefix: str) -> str:
        """Return the first ``f"{prefix}{i}"`` (i = 0, 1, ...) not yet set on ``self.root``."""
        i = 0
        while hasattr(self.root, f"{prefix}{i}"):
            i += 1
        return f"{prefix}{i}"

    def create_arg(self, a: Value) -> torch.fx.Node:  # noqa: C901
        """
        Turn ``a`` into a graph argument, materialising constants on ``self.root``:

        * a parameter of ``self.root`` -> ``get_attr`` on its qualified name;
        * any other parameter -> set on ``self.root`` as ``_param_constant<i>``
          and referenced via ``get_attr``;
        * any other tensor -> registered once as buffer ``_tensor_constant<i>``
          (memoised in ``self.tensor_attrs``) and referenced via ``get_attr``;
        * an ``fx.GraphModule`` (higher-order operator operand) -> added once
          as ``submodule_<n>`` and referenced via ``get_attr``;
        * anything else -> ``fx.Tracer.create_arg``.
        """
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            param_name = self._fresh_attr_name("_param_constant")
            setattr(self.root, param_name, a)
            return self.create_node("get_attr", param_name, (), {})

        if isinstance(a, torch.Tensor):
            tensor_name: Optional[str] = self.tensor_attrs.get(a)
            if not tensor_name:
                tensor_name = self._fresh_attr_name("_tensor_constant")
                self.tensor_attrs[a] = tensor_name
                self.root.register_buffer(tensor_name, a)
            return self.create_node("get_attr", tensor_name, (), {})

        # higher-order operator
        if isinstance(a, fx.GraphModule):
            if a not in self.submodules:
                name_submodule = f"submodule_{len(self.submodules)}"
                self.root.add_module(name_submodule, a)
                self.submodules[a] = name_submodule
            return self.create_node("get_attr", self.submodules[a], (), {})

        return super().create_arg(a)  # pyre-fixme[7]

    @staticmethod
    def get() -> Optional["DispatchTracer"]:
        """Return the tracer installed by ``using_tracer`` (None if there is none)."""
        return TRACER

    def trace(  # pyre-fixme[14,15]
        self,
        root: Callable[..., Value],
        concrete_args: Tuple[Value, ...] = (),
        in_spec: Optional[TreeSpec] = None,
    ) -> Value:
        """
        Traces the given callable with this tracer installed as the current one.

        Returns whatever ``root`` returned; the graph is left in ``self.graph``
        and constants in ``self.root``.
        """
        with using_tracer(self):
            return self._trace(root, concrete_args=concrete_args, in_spec=in_spec)

    def _trace(
        self,
        root: Callable[..., Value],
        concrete_args: Tuple[Value, ...],
        in_spec: Optional[TreeSpec],
    ) -> Value:
        """
        Body of ``trace``.

        Steps:
        1. Reset per-trace state.
        2. Create one placeholder ``ph_<i>`` per entry of ``concrete_args``.
           Tensors become immutable PythonTensors; other values pass through.
        3. If ``in_spec`` is given, unflatten the inputs with it.
        4. Call ``root``.
        5. Emit an output node holding the proxies of the flattened results.

        Raises:
            TypeError: if a flattened output is neither None nor a tensor.
        """
        self.root = torch.nn.Module()
        root_fn = root

        tracer_cls = getattr(self, "__class__", None)
        self.graph = fx.Graph(tracer_cls=tracer_cls)
        # Don't support module, so tensor_attrs is always empty
        self.tensor_attrs = {}
        # The names in self.submodules refer to the root just replaced. Drop
        # them so a re-used tracer re-adds sub-graphs to the new root instead
        # of emitting get_attr nodes for attributes that do not exist.
        self.submodules = {}

        # Wrap all inputs as a PythonTensor subclass and insert them into the FX
        # graph as placeholder nodes
        def wrap(arg: Value, i: int) -> Value:
            placeholder = self.create_proxy("placeholder", f"ph_{i}", (), {})
            if isinstance(arg, torch.Tensor):
                return PythonTensor(arg, placeholder, is_immutable=True)
            else:
                # torch._assert(
                #     placeholder == arg,
                #     f"ph_{i} has been specialized to have value {arg}",
                # )
                return arg

        tree_args = [wrap(arg, i) for i, arg in enumerate(concrete_args)]
        if in_spec:
            tree_args = pytree.tree_unflatten(tree_args, in_spec)

        tree_out = root_fn(*tree_args)

        out_args, _ = pytree.tree_flatten(tree_out)

        def unwrap(out: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
            # it's legit for a model to return a list of items some of which
            # are None
            if out is None:
                return None
            if not isinstance(out, torch.Tensor):
                raise TypeError(
                    f"Expect model to return torch.Tensor, got type: '{type(out)}' (value: {out})."
                )
            return unwrap_proxy(out)

        returns = [unwrap(out) for out in out_args]

        return_annotation = None
        # some ops like torch.sub doesn't have annotations
        if hasattr(root_fn, "__annotations__"):
            return_annotation = root_fn.__annotations__.get("return", None)

        self.create_proxy(
            "output",
            "output",
            (returns,),
            {},
            type_expr=return_annotation,
        )

        self.submodule_paths = None

        return tree_out
# The tracer installed by `using_tracer`, returned by `DispatchTracer.get()`;
# None unless a tracer has been installed.
TRACER: Optional[DispatchTracer] = None
# Toggled by `using_dynamo`. When set, `dispatch_trace` first passes the
# callable through `dynamo_trace` before dispatch tracing.
TORCHDYNAMO_ENABLED: bool = False
@contextmanager
def using_dynamo(val: bool) -> Generator[None, None, None]:
    """
    Temporarily set the module-level ``TORCHDYNAMO_ENABLED`` flag to ``val``.

    The previous value is restored on exit, even if the body raises.
    """
    global TORCHDYNAMO_ENABLED
    saved = TORCHDYNAMO_ENABLED
    TORCHDYNAMO_ENABLED = val
    try:
        yield
    finally:
        TORCHDYNAMO_ENABLED = saved
def flattened_dispatch_trace(
    f: Callable[..., Value],
    args: Tuple[LeafValue, ...],
    guards: Set[Guard],
    in_spec: Optional[TreeSpec] = None,
    enable_functionalization: bool = True,
) -> Tuple[torch.fx.GraphModule, Value]:
    """
    Dispatch-trace ``f`` on already-flattened ``args`` and build a GraphModule.

    Args:
        f: callable (function or ``nn.Module``) to trace.
        args: flat tuple of inputs, one graph placeholder per entry.
        guards: accepted for interface compatibility; not used here.
        in_spec: when given, used by the tracer to unflatten ``args`` before
            calling ``f``.
        enable_functionalization: wrap ``f`` with
            ``torch.func.functionalize(..., remove="mutations_and_views")``
            before tracing.

    Returns:
        ``(graph_module, tree_out)``, where ``tree_out`` is what ``f``
        returned during tracing.

    Raises:
        TypeError: if ``args`` is not a tuple.
    """
    if not isinstance(args, tuple):
        raise TypeError(f"Expecting 'args' to be a tuple, got: {type(args)}")

    tracer = DispatchTracer()
    if enable_functionalization:
        f = functionalize(f, remove="mutations_and_views")
    tree_out = tracer.trace(f, concrete_args=args, in_spec=in_spec)

    # The GraphModule class is named after the traced callable. Callables
    # without `__name__` (e.g. functools.partial or arbitrary callable
    # objects) fall back to their type name instead of raising AttributeError.
    if isinstance(f, torch.nn.Module):
        name = type(f).__name__
    else:
        name = getattr(f, "__name__", type(f).__name__)
    gm = torch.fx.GraphModule(tracer.root, tracer.graph, name)
    return (gm, tree_out)
@dataclass
class ExirDynamoConfig:
    """
    Manage Exir-specific configurations of Dynamo.

    ``dynamo_trace`` converts an instance with ``asdict`` and passes the
    result to ``torch._dynamo.config.patch``, so each field name is expected
    to be a ``torch._dynamo.config`` option. ``assume_static_by_default`` is
    also forwarded directly to ``torch._dynamo.export``.
    """

    # Applied via torch._dynamo.config.patch in dynamo_trace.
    allow_rnn: bool = True
    # Applied via torch._dynamo.config.patch in dynamo_trace.
    verbose: bool = True
    # Applied via config.patch and also passed as
    # export(assume_static_by_default=...).
    assume_static_by_default: bool = False
def flatten_output(gm: torch.fx.GraphModule) -> None:
    """
    Rewrite the output node of ``gm``'s graph so it returns a flat list.

    The (possibly nested) value returned by the output node is flattened with
    ``torch.utils._pytree`` and stored back as its single argument. This keeps
    it consistent with the flat outputs of EXIR's tracer. Only the last output
    node in the graph is changed, and ``gm`` is not recompiled here.

    Raises:
        RuntimeError: if the graph has no output node.
    """
    output_node = next(
        (node for node in reversed(gm.graph.nodes) if node.op == "output"), None
    )
    if output_node is None:
        raise RuntimeError(f"Could not find an output node in {gm.graph}")
    assert len(output_node.args) == 1
    flat_returns, _ = pytree.tree_flatten(output_node.args[0])
    output_node.args = (flat_returns,)
def _default_decomposition_table(
    _use_old_decomp_table=False,
) -> Dict[torch._ops.OpOverload, Callable[..., Value]]:
    """
    Build the decomposition table used when ``dynamo_trace`` exports an aten graph.

    With ``_use_old_decomp_table`` only a legacy, hand-picked set of ops is
    decomposed. Otherwise the table is built in three steps:

    1. Start from ``torch.export.default_decompositions()``.
    2. Add the edge-specific decompositions.
    3. Drop every op in ``_QUANT_PRIMITIVES``, so quantization primitives are
       never decomposed.
    """
    if _use_old_decomp_table:
        legacy_ops = [
            torch.ops.aten.log_sigmoid_forward,
            torch.ops.aten.ones,
            torch.ops.aten.arange.default,
            torch.ops.aten.arange.start,
            torch.ops.aten.transpose,
        ]
        return get_decompositions(legacy_ops)  # pyre-fixme[7]

    table = default_decompositions()
    # Edge-specific decompositions.
    # TODO: Eventually this op should be added to the core decompo table, and
    # will not need to be added here.
    table.update(get_decompositions([torch.ops.aten.linalg_vector_norm.default]))

    # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
    for quant_op in _QUANT_PRIMITIVES:
        table.pop(quant_op, None)
    return table  # pyre-fixme[7]
def dynamo_trace(
    f: Callable[..., Value],
    # pyre-ignore
    args: Tuple[Any, ...],
    aten_graph: bool,
    tracing_mode: str = "real",
    dynamo_config: Optional[ExirDynamoConfig] = None,
    # pyre-ignore
    dynamic_shapes: Optional[List[Any]] = None,
    _use_old_decomp_table: bool = False,
) -> Tuple[torch.fx.GraphModule, Set[Guard]]:
    """
    Export ``f`` with ``torch._dynamo.export`` on a deep copy of ``args``.

    The export runs with ``dynamo_config`` (default ``ExirDynamoConfig()``)
    patched into ``torch._dynamo.config``, a Python recursion limit of 2000,
    and a fresh Dynamo state (``torchdynamo.reset()``). When ``aten_graph``
    is set, ``_default_decomposition_table`` supplies the decompositions.

    TODO: Once we fully migrate to the torchdynamo frontend, this config
    option will be removed altogether. For now it helps with quick
    experiments with TorchDynamo.

    Returns:
        Whatever the export call returns. NOTE(review): annotated as
        (GraphModule, guards), but pyre reports an ``ExportResult``.

    Raises:
        ExportError: (NOT_SUPPORTED) when Dynamo raises ``Unsupported``.
        InternalError: for any other exception during export.
    """
    config = ExirDynamoConfig() if dynamo_config is None else dynamo_config
    with torchdynamo.config.patch(asdict(config)), setting_python_recursive_limit(2000):
        torchdynamo.reset()
        try:
            # TODO merge executorch functionalization with official
            # functionalization
            decomposition_table = (
                _default_decomposition_table(_use_old_decomp_table)
                if aten_graph
                else None
            )
            exporter = torchdynamo.export(
                f,
                aten_graph=aten_graph,
                tracing_mode=tracing_mode,
                assume_static_by_default=config.assume_static_by_default,
                decomposition_table=decomposition_table,
                dynamic_shapes=dynamic_shapes,
            )
            # pyre-fixme[7]: Expected `Tuple[GraphModule, Set[Guard]]` but got
            # `ExportResult`.
            return exporter(*copy.deepcopy(args))
        except torchdynamo.exc.Unsupported as exc:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "The user code is using a feature we don't support. "
                "Please try torchdynamo.explain() to get possible the reasons",
            ) from exc
        except Exception as exc:
            raise InternalError(
                "torchdynamo internal error occurred. Please see above stacktrace"
            ) from exc
def dispatch_trace(
    f: Callable[..., Value],
    args: Tuple[Value, ...],
) -> torch.fx.GraphModule:
    """
    Executes a given callable `f` with a given tuple of arguments. During
    execution, Tensor operations are recorded in a fx.GraphModule, which is
    then returned.

    Tracing happens in two passes, each over a fresh deep copy of the
    flattened ``args``:

    1. ``f`` is dispatch-traced without functionalization, and dead code is
       eliminated from the resulting graph module.
    2. That graph module is dispatch-traced again with functionalization,
       through a wrapper that runs it under ``preserve_node_meta()``. This
       keeps the stack traces recorded in pass 1.

    When ``TORCHDYNAMO_ENABLED`` is set, ``f`` is first replaced by the
    result of ``dynamo_trace``.

    Args:
        f: A `nn.Module` or a Python function that implements an ML program.
        args: A tuple of arguments of any type to be used as inputs for the
            tracing run.

    Returns:
        EXIR contained in a fx.GraphModule, with ``in_spec`` / ``out_spec``
        attributes describing the pytree structure of ``args`` and of the
        traced output.
    """
    trace_func = f
    guards = set()
    if TORCHDYNAMO_ENABLED:
        trace_func, guards = dynamo_trace(trace_func, args, False)

    flat_args, in_spec = pytree.tree_flatten(args)

    # Pass 1: plain dispatch trace. Inputs are deep-copied in case downstream
    # implementations of trace_func mutate them.
    first_gm, first_tree_out = flattened_dispatch_trace(
        trace_func,
        copy.deepcopy(tuple(flat_args)),
        guards,
        in_spec,
        enable_functionalization=False,
    )
    _, out_spec = pytree.tree_flatten(first_tree_out)
    # pyre-fixme[16]: `GraphModule` has no attribute `in_spec`.
    first_gm.in_spec = in_spec
    # pyre-fixme[16]: `GraphModule` has no attribute `out_spec`.
    first_gm.out_spec = out_spec

    # TODO (tmanlaibaatar) This is bit clowny, but our
    # dispatch_trace sometimes creates unused node that
    # breaks functionalization. it seems too much trouble
    # to fix it properly since dispatch_trace will be deprecated soon.
    # Basically dispatch_trace struggles on:
    # def f(x: torch.Tensor) -> torch.Tensor:
    #     return torch.ones(6, dtype=x.dtype)
    if first_gm.graph.eliminate_dead_code():
        first_gm.recompile()

    second_pass_args = copy.deepcopy(tuple(flat_args))
    assert callable(first_gm)

    # This wrapper is used for preserving the stacktrace
    # during second round of tracing.
    # pyre-ignore
    def graph_with_interpreter(*user_args):
        try:
            flat_inputs = fx_pytree.tree_flatten_spec(user_args, first_gm.in_spec)
        except Exception:
            _, received_spec = pytree.tree_flatten(user_args)
            raise RuntimeError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{first_gm.in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        with torch.fx.traceback.preserve_node_meta():
            res = first_gm(*flat_inputs)

        if first_gm.out_spec is None:
            return res
        try:
            return pytree.tree_unflatten(res, first_gm.out_spec)
        except Exception:
            _, received_spec = pytree.tree_flatten(res)
            raise RuntimeError(
                "Trying to flatten user outputs with exported output tree spec: \n"
                f"{first_gm.out_spec}\n"
                "but actually got outputs with tree spec of: \n"
                f"{received_spec}"
            )

    # Pass 2: re-trace the pass-1 module with functionalization enabled.
    final_gm, _ = flattened_dispatch_trace(
        graph_with_interpreter,
        second_pass_args,
        guards,
        in_spec,
        enable_functionalization=True,
    )
    final_gm.in_spec = in_spec
    final_gm.out_spec = out_spec
    return final_gm