Skip to content

Commit e68412e

Browse files
committed
spell
1 parent a8ab456 commit e68412e

2 files changed

Lines changed: 54 additions & 46 deletions

File tree

_doc/cmds/validate.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ About the exporter 'custom'
159159
+++++++++++++++++++++++++++
160160

161161
It used to investigate issues or scenarios. It is usually very strict
162-
and fails everytime it falls in one unexpected situation.
162+
and fails every time it falls in one unexpected situation.
163163
It call :func:`experimental_experiment.torch_interpreter.to_onnx`.
164164
Some useful environment variables to set before running the command line.
165165

onnx_diagnostic/export/api.py

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,54 @@
33
from .onnx_plug import EagerDirectReplacementWithOnnx
44

55

6+
def get_main_dispatcher(
7+
use_control_flow_dispatcher: bool = False,
8+
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
9+
) -> "Dispatcher": # noqa: F821
10+
"""
11+
Creates a custom dispatcher for the custom exporter.
12+
"""
13+
from experimental_experiment.torch_interpreter import Dispatcher
14+
15+
if use_control_flow_dispatcher:
16+
from .control_flow_onnx import create_global_dispatcher
17+
18+
control_flow_dispatcher = create_global_dispatcher()
19+
else:
20+
control_flow_dispatcher = None
21+
22+
class MainDispatcher(Dispatcher):
23+
def __init__(self, previous_dispatcher=None):
24+
super().__init__({})
25+
self.previous_dispatcher = previous_dispatcher
26+
27+
@property
28+
def supported(self):
29+
if self.previous_dispatcher:
30+
return set(self.registered_functions) | self.previous_dispatcher.supported
31+
return set(self.registered_functions)
32+
33+
def find_function(self, name: Any):
34+
if self.previous_dispatcher:
35+
find = self.previous_dispatcher.find_function(name)
36+
if find:
37+
return find
38+
return Dispatcher.find_function(self, name)
39+
40+
def find_method(self, name: Any):
41+
if self.previous_dispatcher:
42+
find = self.previous_dispatcher.find_method(name)
43+
if find:
44+
return find
45+
return Dispatcher.find_method(self, name)
46+
47+
main_dispatcher = MainDispatcher(control_flow_dispatcher)
48+
if onnx_plugs:
49+
for plug in onnx_plugs:
50+
main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
51+
return main_dispatcher
52+
53+
654
def to_onnx(
755
mod: Union["torch.nn.Module", "torch.fx.GraphModule"], # noqa: F821
856
args: Optional[Sequence["torch.Tensor"]] = None, # noqa: F821
@@ -82,51 +130,11 @@ def to_onnx(
82130
options = exporter_kwargs.pop("options", None)
83131
if options is None:
84132
options = OptimizationOptions(patterns="default+onnxruntime")
85-
if onnx_plugs or use_control_flow_dispatcher:
86-
from experimental_experiment.torch_interpreter import Dispatcher
87-
88-
if use_control_flow_dispatcher:
89-
from .control_flow_onnx import create_global_dispatcher
90-
91-
control_flow_dispatcher = create_global_dispatcher()
92-
else:
93-
control_flow_dispatcher = None
94-
95-
class MainDispatcher(Dispatcher):
96-
def __init__(self, previous_dispatcher=None):
97-
super().__init__({})
98-
self.previous_dispatcher = previous_dispatcher
99-
100-
@property
101-
def supported(self):
102-
if self.previous_dispatcher:
103-
return (
104-
set(self.registered_functions) | self.previous_dispatcher.supported
105-
)
106-
return set(self.registered_functions)
107-
108-
def find_function(self, name: Any):
109-
if self.previous_dispatcher:
110-
find = self.previous_dispatcher.find_function(name)
111-
if find:
112-
return find
113-
return Dispatcher.find_function(self, name)
114-
115-
def find_method(self, name: Any):
116-
if self.previous_dispatcher:
117-
find = self.previous_dispatcher.find_method(name)
118-
if find:
119-
return find
120-
return Dispatcher.find_method(self, name)
121-
122-
main_dispatcher = MainDispatcher(control_flow_dispatcher)
123-
if onnx_plugs:
124-
for plug in onnx_plugs:
125-
main_dispatcher.registered_functions[plug.target_name] = (
126-
plug.custom_converter()
127-
)
128-
else:
129-
main_dispatcher = None
133+
main_dispatcher = (
134+
get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
135+
if onnx_plugs or use_control_flow_dispatcher
136+
else None
137+
)
130138

131139
return _to_onnx(
132140
mod,

0 commit comments

Comments
 (0)