Commit ce08cbb

[tmva][sofie] Run PyTorch ONNX export in a separate process in tutorial
The tutorial exported a PyTorch model to ONNX and then parsed it with SOFIE in the same process. `torch.onnx` and ROOT's SOFIE ONNX parser are both linked against protobuf, but generally against different versions, so loading them into one process causes a symbol clash and aborts.

Move the model creation, training and ONNX export into a small standalone script that runs in its own Python process via subprocess, before ROOT is imported, so the two protobuf copies are never loaded together. The parent detects success by the presence of the generated .onnx file and raises a RuntimeError if it is missing.
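As an illustration of the isolation this message relies on, here is a minimal sketch that does not use PyTorch or ROOT. It passes an argument to a script run with `python -c` and confirms the script executes in a different process. The names `CHILD_SCRIPT` and `run_child` are hypothetical and do not appear in the tutorial.

```python
import os
import subprocess
import sys

# Hypothetical stand-in for the export script. With "python -c <code> <arg>",
# sys.argv[0] is "-c" and sys.argv[1] is the first extra argument.
CHILD_SCRIPT = r"""
import os
import sys
print(sys.argv[1], os.getpid())
"""


def run_child(name):
    """Run CHILD_SCRIPT in a fresh interpreter and return (name, child_pid)."""
    result = subprocess.run(
        [sys.executable, "-c", CHILD_SCRIPT, name],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError on a non-zero exit status
    )
    got_name, pid = result.stdout.split()
    return got_name, int(pid)


if __name__ == "__main__":
    name, child_pid = run_child("LinearModel")
    print("child received", name, "and ran as pid", child_pid)
    print("parent pid is", os.getpid())
```

Because the child is a separate interpreter, any shared libraries it loads stay in its own address space and are released when it exits.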
1 parent 74ef037 commit ce08cbb

1 file changed

Lines changed: 32 additions & 10 deletions

tutorials/machine_learning/TMVA_SOFIE_ONNX.py

@@ -7,28 +7,42 @@
 ## - compiling the model using ROOT Cling
 ## - run the code and optionally compare with ONNXRuntime
 ##
+## The PyTorch export and ROOT's SOFIE parser are both linked against protobuf,
+## but usually against different versions, so loading them in the same process
+## leads to a symbol clash. We therefore run the PyTorch -> ONNX export in a
+## separate Python process and only import ROOT afterwards.
 ##
 ## \macro_code
 ## \macro_output
 ## \author Lorenzo Moneta
 
+import os
+import sys
+import subprocess
 
-import contextlib
+import numpy as np
+import ROOT
+
+
+# The PyTorch export, as a small standalone script run in its own process.
+# It takes the model name as its only argument and writes <modelName>.onnx.
+EXPORT_SCRIPT = r"""
+import sys
 import inspect
 import warnings
+import contextlib
 
-import numpy as np
-import ROOT
 import torch
 import torch.nn as nn
 
+modelName = sys.argv[1]
+
 
 @contextlib.contextmanager
 def expect_warning(category, message):
-    """Silence a known third-party warning and raise if it stops firing.
+    # Silence a known third-party warning and raise if it stops firing.
 
-    Notifies us to drop the workaround once the upstream library is fixed.
-    """
+    # Notifies us to drop the workaround once the upstream library is fixed.
     with warnings.catch_warnings(record=True) as caught:
         warnings.simplefilter("always")
         yield
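The hunk stops at `yield`, so the rest of `expect_warning` is not shown on this page. The following is a minimal sketch of the technique its comments describe: silence one known warning and fail once it stops appearing. The name `expect_warning_sketch` and everything after `yield` are assumptions, not the tutorial's code.

```python
import contextlib
import warnings


@contextlib.contextmanager
def expect_warning_sketch(category, message):
    # Record every warning raised inside the block instead of printing it.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        yield
    # The original warning filters are restored once the "with" above exits.
    matched = [
        w for w in caught
        if issubclass(w.category, category) and message in str(w.message)
    ]
    # Re-emit unrelated warnings so they are not silently swallowed.
    for w in caught:
        if w not in matched:
            warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)
    # Assumed behaviour: if the known warning no longer fires, fail loudly so
    # the workaround gets removed.
    if not matched:
        raise AssertionError(
            "expected warning %r did not fire; the workaround may be obsolete"
            % message
        )


if __name__ == "__main__":
    with expect_warning_sketch(UserWarning, "known issue"):
        warnings.warn("known issue in a third-party library", UserWarning)
    print("known warning was silenced")
```

Putting the check after the `with` block means an exception raised inside the wrapped code propagates unchanged and is not masked by the missing-warning error.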
@@ -97,8 +111,11 @@ def filtered_kwargs(func, **candidate_kwargs):
         return modelFile
     except TypeError:
         print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
-        print("Skip tutorial execution")
-        exit()
+        # leave no .onnx behind: which the parent process treats as a RuntimeError
+        sys.exit()
+
+
+CreateAndTrainModel(modelName)
+"""
 
 
 def ParseModel(modelFile, verbose=False):
@@ -127,12 +144,17 @@ def ParseModel(modelFile, verbose=False):
 
 
 ###################################################################
-## Step 1 : Create and Train model
+## Step 1 : Create and train the model, export it to ONNX
+## (done in a separate process to avoid the protobuf clash)
 ###################################################################
 
 # use an arbitrary modelName
 modelName = "LinearModel"
-modelFile = CreateAndTrainModel(modelName)
+modelFile = modelName + ".onnx"
+
+subprocess.run([sys.executable, "-c", EXPORT_SCRIPT, modelName])
+if not os.path.exists(modelFile):
+    raise RuntimeError("ONNX model could not be exported")
 
 
 ###################################################################
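The parent in this diff decides success only from whether the `.onnx` file exists. Two details follow from that. `sys.exit()` with no argument exits with status 0, so the return code alone would not reveal the `TypeError` path. A file left over from an earlier run would also satisfy `os.path.exists`. A stricter variant could look like the sketch below. The helper name `export_in_subprocess` and its `workdir` parameter are hypothetical and not part of the commit.

```python
import os
import subprocess
import sys


def export_in_subprocess(script, model_name, workdir="."):
    """Run `script` in a child interpreter and return the .onnx path it wrote."""
    model_file = os.path.join(workdir, model_name + ".onnx")
    # A file left over from an earlier run would make the existence check pass.
    if os.path.exists(model_file):
        os.remove(model_file)
    result = subprocess.run(
        [sys.executable, "-c", script, model_name],
        cwd=workdir,
        capture_output=True,
        text=True,
    )
    # Catches crashes and explicit non-zero exits in the child.
    if result.returncode != 0:
        raise RuntimeError(
            f"export exited with status {result.returncode}:\n{result.stderr}"
        )
    # Catches a clean exit that produced no model, e.g. a bare sys.exit().
    if not os.path.exists(model_file):
        raise RuntimeError("export finished but wrote no " + model_file)
    return model_file
```

With this helper, the parent side of the tutorial would reduce to a call such as `modelFile = export_in_subprocess(EXPORT_SCRIPT, modelName)`, assuming the script writes `<modelName>.onnx` into its working directory.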
