Skip to content

Commit 3a8d719

Browse files
authored
Fix NeutronConverterManager pickle error with forkserver multiprocessing (#19855) (#19855)
Summary: Refactors convert_unsafe() to pass picklable dict instead of unpicklable module/C++ objects, adds TypeError to fallback handler (both fbcode + xplat copies) Differential Revision: D106689031
1 parent 40b0a35 commit 3a8d719

3 files changed

Lines changed: 69 additions & 19 deletions

File tree

backends/nxp/backend/neutron_converter_manager.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,29 @@
1515
)
1616

1717

18-
def convert_unsafe(neutron_converter, tflite_model, cctx, queue):
18+
def _build_compilation_context(compilation_opts):
19+
"""Build a CompilationContext from a plain dict of options."""
20+
cctx = neutron_converter.CompilationContext()
21+
cctx.targetOpts = neutron_converter.getNeutronTarget(compilation_opts["target"])
22+
cctx.compilationOpts.minNumOpsPerGraph = compilation_opts["minNumOpsPerGraph"]
23+
cctx.compilationOpts.excludeGraphPasses = compilation_opts["excludeGraphPasses"]
24+
cctx.compilationOpts.fetchConstantsToSRAM = compilation_opts["fetchConstantsToSRAM"]
25+
cctx.compilationOpts.dumpKernelSelectionCode = compilation_opts[
26+
"dumpKernelSelectionCode"
27+
]
28+
if hasattr(cctx.compilationOpts, "useNewFlowNeutronC"):
29+
cctx.compilationOpts.useNewFlowNeutronC = compilation_opts["useNewFlowNeutronC"]
30+
return cctx
31+
32+
33+
def convert_unsafe(tflite_model, compilation_opts, queue):
1934
"""
20-
Run neutron_converter on given tflite_model with compilation context cctx.
35+
Run neutron_converter on given tflite_model with the provided compilation options.
2136
This routine is supposed to run in a separate process.
2237
If properly finished, the output queue contains the converted model,
2338
otherwise the neutron_converter exits and the output queue is empty.
2439
"""
40+
cctx = _build_compilation_context(compilation_opts)
2541
model_converted = neutron_converter.convertModel(list(tflite_model), cctx)
2642
queue.put(model_converted)
2743

@@ -84,16 +100,14 @@ def convert(
84100
# Neutron converter crashes if we provide invalid target -> verify.
85101
self.verify_target(target)
86102

87-
cctx = neutron_converter.CompilationContext()
88-
cctx.targetOpts = neutron_converter.getNeutronTarget(target)
89-
cctx.compilationOpts.minNumOpsPerGraph = 1
90-
cctx.compilationOpts.excludeGraphPasses = (
91-
"HoistSliceAboveTranspose,MergeTranspose"
92-
)
93-
cctx.compilationOpts.fetchConstantsToSRAM = fetch_constants_to_sram
94-
cctx.compilationOpts.dumpKernelSelectionCode = self.dump_kernel_selection_code
95-
if hasattr(cctx.compilationOpts, "useNewFlowNeutronC"):
96-
cctx.compilationOpts.useNewFlowNeutronC = use_new_flow_neutron_c
103+
compilation_opts = {
104+
"target": target,
105+
"minNumOpsPerGraph": 1,
106+
"excludeGraphPasses": "HoistSliceAboveTranspose,MergeTranspose",
107+
"fetchConstantsToSRAM": fetch_constants_to_sram,
108+
"dumpKernelSelectionCode": self.dump_kernel_selection_code,
109+
"useNewFlowNeutronC": use_new_flow_neutron_c,
110+
}
97111

98112
# Try to use multiprocessing for isolation, but fall back to direct execution
99113
# if the environment doesn't support it (e.g., in sandcastle/build environments)
@@ -104,7 +118,7 @@ def convert(
104118

105119
process = multiprocessing.Process(
106120
target=convert_unsafe,
107-
args=(neutron_converter, tflite_model, cctx, queue),
121+
args=(tflite_model, compilation_opts, queue),
108122
)
109123
process.start()
110124
process.join() # waits until the subprocess is complete
@@ -116,12 +130,13 @@ def convert(
116130

117131
model_converted = queue.get()
118132
process.close()
119-
except (EOFError, OSError) as e:
133+
except (EOFError, OSError, TypeError) as e:
120134
# Multiprocessing failed (likely due to environment restrictions)
121135
# Fall back to direct execution
122136
logging.warning(
123137
f"Multiprocessing not available ({e}), running neutron converter directly"
124138
)
139+
cctx = _build_compilation_context(compilation_opts)
125140
model_converted = neutron_converter.convertModel(list(tflite_model), cctx)
126141
if self.dump_kernel_selection_code:
127142
self._rename_partition_kernel_selection_file(delegation_tag)

backends/nxp/tests/BUCK

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,20 @@ fbcode_target(_kind = python_pytest,
112112
],
113113
)
114114

115+
fbcode_target(_kind = python_pytest,
116+
name = "test_neutron_converter_manager",
117+
srcs = [
118+
"generic_tests/test_neutron_converter_manager.py",
119+
],
120+
deps = [
121+
"//executorch/backends/nxp:neutron_sdk",
122+
"//executorch/exir:lib",
123+
":executorch_pipeline",
124+
":models",
125+
"fbsource//third-party/pypi/pytest-mock:pytest-mock", # @manual
126+
],
127+
)
128+
115129
fbcode_target(_kind = python_pytest,
116130
name = "test_integration",
117131
srcs = [

backends/nxp/tests/generic_tests/test_neutron_converter_manager.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import multiprocessing
7+
import pickle
78

89
import torch
9-
from eiq_neutron_sdk.neutron_converter.neutron_converter import CompilationContext
1010
from executorch import exir
1111
from executorch.backends.nxp.backend.edge_program_converter import (
1212
EdgeProgramToIRConverter,
@@ -69,7 +69,28 @@ def test_neutron_converter_with_experimental_mlir_flow(mocker):
6969
model, input_shape, use_new_flow_neutron_c=True
7070
).exported_program()
7171

72-
compilation_context = process_spy.call_args.kwargs["args"][2]
73-
assert isinstance(compilation_context, CompilationContext)
74-
if hasattr(compilation_context.compilationOpts, "useNewFlowNeutronC"):
75-
assert compilation_context.compilationOpts.useNewFlowNeutronC
72+
compilation_opts = process_spy.call_args.kwargs["args"][1]
73+
assert isinstance(compilation_opts, dict)
74+
assert compilation_opts["useNewFlowNeutronC"] is True
75+
76+
77+
def test_convert_unsafe_args_are_picklable(mocker):
78+
"""Verify that all args passed to `multiprocessing.Process` are picklable.
79+
80+
The subprocess uses forkserver/spawn in some environments, which requires
81+
all Process args to be serializable via pickle.
82+
"""
83+
model = LinearModule(True)
84+
input_shape = (1, 1, 32, 32)
85+
86+
process_spy = mocker.spy(multiprocessing, "Process")
87+
to_quantized_edge_program(model, input_shape).exported_program()
88+
89+
args = process_spy.call_args.kwargs["args"]
90+
for i, arg in enumerate(args):
91+
try:
92+
pickle.dumps(arg)
93+
except (pickle.PicklingError, TypeError) as e:
94+
raise AssertionError(
95+
f"Process arg at index {i} ({type(arg).__name__}) is not picklable: {e}"
96+
)

0 commit comments

Comments
 (0)