|
7 | 7 | ## - compiling the model using ROOT Cling |
8 | 8 | ## - run the code and optionally compare with ONNXRuntime |
9 | 9 | ## |
| 10 | +## The PyTorch export and ROOT's SOFIE parser are both linked against protobuf, |
| 11 | +## but usually against different versions, so loading them in the same process |
| 12 | +## leads to a symbol clash. We therefore run the PyTorch -> ONNX export in a |
| 13 | +## separate Python process and only import ROOT afterwards. |
10 | 14 | ## |
11 | 15 | ## \macro_code |
12 | 16 | ## \macro_output |
13 | 17 | ## \author Lorenzo Moneta |
14 | 18 |
|
| 19 | +import os |
| 20 | +import sys |
| 21 | +import subprocess |
15 | 22 |
|
16 | | -import contextlib |
| 23 | +import numpy as np |
| 24 | +import ROOT |
| 25 | + |
| 26 | + |
| 27 | +# The PyTorch export, as a small standalone script run in its own process. |
| 28 | +# It takes the model name as its only argument and writes <modelName>.onnx. |
| 29 | +EXPORT_SCRIPT = r""" |
| 30 | +import sys |
17 | 31 | import inspect |
18 | 32 | import warnings |
| 33 | +import contextlib |
19 | 34 |
|
20 | | -import numpy as np |
21 | | -import ROOT |
22 | 35 | import torch |
23 | 36 | import torch.nn as nn |
24 | 37 |
|
| 38 | +modelName = sys.argv[1] |
| 39 | +
|
25 | 40 |
|
26 | 41 | @contextlib.contextmanager |
27 | 42 | def expect_warning(category, message): |
28 | | - """Silence a known third-party warning and raise if it stops firing. |
| 43 | + # Silence a known third-party warning and raise if it stops firing. |
29 | 44 |
|
30 | | - Notifies us to drop the workaround once the upstream library is fixed. |
31 | | - """ |
| 45 | + # Notifies us to drop the workaround once the upstream library is fixed. |
32 | 46 | with warnings.catch_warnings(record=True) as caught: |
33 | 47 | warnings.simplefilter("always") |
34 | 48 | yield |
@@ -97,8 +111,11 @@ def filtered_kwargs(func, **candidate_kwargs): |
97 | 111 | return modelFile |
98 | 112 | except TypeError: |
99 | 113 | print("Cannot export model from pytorch to ONNX - with version ", torch.__version__) |
100 | | - print("Skip tutorial execution") |
101 | | - exit() |
| 114 | + # leave no .onnx behind: which the parent process treats as a RuntimeError |
| 115 | + sys.exit() |
| 116 | +
|
| 117 | +CreateAndTrainModel(modelName) |
| 118 | +""" |
102 | 119 |
|
103 | 120 |
|
104 | 121 | def ParseModel(modelFile, verbose=False): |
@@ -127,12 +144,17 @@ def ParseModel(modelFile, verbose=False): |
127 | 144 |
|
128 | 145 |
|
129 | 146 | ################################################################### |
130 | | -## Step 1 : Create and Train model |
| 147 | +## Step 1 : Create and train the model, export it to ONNX |
| 148 | +## (done in a separate process to avoid the protobuf clash) |
131 | 149 | ################################################################### |
132 | 150 |
|
133 | 151 | # use an arbitrary modelName |
134 | 152 | modelName = "LinearModel" |
135 | | -modelFile = CreateAndTrainModel(modelName) |
| 153 | +modelFile = modelName + ".onnx" |
| 154 | + |
| 155 | +subprocess.run([sys.executable, "-c", EXPORT_SCRIPT, modelName]) |
| 156 | +if not os.path.exists(modelFile): |
| 157 | + raise RuntimeError("ONNX model could not be exported") |
136 | 158 |
|
137 | 159 |
|
138 | 160 | ################################################################### |
|
0 commit comments