Skip to content

Commit fc94d80

Browse files
committed
Replace getattr with constant when buffer is a graph output
`remove_getattr_nodes` raised when a model directly returned a registered buffer / parameter (e.g. `forward(self): return self.buf`): RuntimeError: <name> should not be in the graph outputs. The pre-existing `flatten_graph_output_values` pass already lifts buffer-returning getattr nodes into the graph outputs, so the bare `raise` had no useful effect — there is no other op that lowers them. Replace each output-producing getattr with a `constant` node carrying the buffer value (read from `graph.params`). The downstream `constant` op handler then materializes it as a const var. Intermediate getattr nodes are still dropped as before, since their consumers read directly from `graph.params`. Fixes #2538.
1 parent 5256644 commit fc94d80

2 files changed

Lines changed: 93 additions & 11 deletions

File tree

coremltools/converters/mil/frontend/torch/test/test_passes.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ..torchir_passes import (
1818
flatten_graph_input_values,
1919
flatten_graph_output_values,
20+
remove_getattr_nodes,
2021
transform_inplace_ops,
2122
)
2223
import coremltools as ct
@@ -113,6 +114,63 @@ def test_flatten_input_values():
113114
np.testing.assert_equal(graph.nodes[1].outputs[0], graph.nodes[2].inputs[0])
114115

115116

117+
@staticmethod
def test_remove_getattr_nodes_with_output_buffer():
    # Regression test for #2538: when a model directly returns a buffer,
    # the corresponding getattr appears in graph.outputs and the pass
    # used to raise. It now replaces the getattr with a constant node
    # holding the buffer value so conversion can proceed.
    params = {
        "buf_a": np.array([1.0, 2.0, 3.0], dtype=np.float32),
        "buf_b": np.array([4.0, 5.0], dtype=np.float32),
    }
    graph_nodes = [
        InternalTorchIRNode(inputs=[], outputs=["buf_a"], kind="getattr", name="buf_a"),
        InternalTorchIRNode(inputs=[], outputs=["buf_b"], kind="getattr", name="buf_b"),
    ]
    graph = InternalTorchIRGraph(
        nodes=graph_nodes,
        params=params,
        inputs=OrderedDict(),
        outputs=["buf_a", "buf_b"],
    )

    remove_getattr_nodes(graph)

    # Both getattrs must be rewritten into constant nodes that carry the
    # buffer values and keep the original output names.
    np.testing.assert_equal(len(graph.nodes), 2)
    for node in graph.nodes:
        np.testing.assert_equal(node.kind, "constant")
    np.testing.assert_array_equal(graph.nodes[0].attr["value"], params["buf_a"])
    np.testing.assert_array_equal(graph.nodes[1].attr["value"], params["buf_b"])
    # Original output names are preserved.
    np.testing.assert_equal(graph.nodes[0].outputs, ["buf_a"])
    np.testing.assert_equal(graph.nodes[1].outputs, ["buf_b"])
148+
149+
150+
@staticmethod
def test_remove_getattr_nodes_drops_intermediate():
    # Sanity check: a getattr node that is *not* in graph.outputs should
    # still be dropped (the consuming op handler reads from graph.params).
    params = {"weight": np.array([1.0], dtype=np.float32)}
    graph_nodes = [
        InternalTorchIRNode(inputs=[], outputs=["weight"], kind="getattr", name="weight"),
        InternalTorchIRNode(
            inputs=["x", "weight"], outputs=["y"], kind="mul", name="y"
        ),
    ]
    graph = InternalTorchIRGraph(
        nodes=graph_nodes,
        params=params,
        inputs=OrderedDict([("x", torch.rand(1))]),
        outputs=["y"],
    )

    remove_getattr_nodes(graph)

    # Only the consuming op remains; the intermediate getattr is gone.
    np.testing.assert_equal(len(graph.nodes), 1)
    np.testing.assert_equal(graph.nodes[0].kind, "mul")
172+
173+
116174
@staticmethod
117175
def test_flatten_output_values():
118176
graph = _build_flattening_test_graph()

coremltools/converters/mil/frontend/torch/torchir_passes.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from collections import OrderedDict, defaultdict
77
from typing import Dict, Optional
88

9+
import numpy as np
10+
911
from coremltools import _logger as logger
1012

1113
from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode
@@ -229,30 +231,52 @@ def forward(self, x):
229231
node.model_hierarchy = cached_model_hierarchy[child_ops[node.name][0]]
230232

231233

232-
def remove_getattr_nodes(
    graph: "InternalTorchIRGraph",
    params: Optional[Dict[str, "np.ndarray"]] = None,
) -> None:
    """
    Remove the getattr nodes from the graph.

    A getattr node typically references a buffer / parameter that is consumed
    by another op; the consuming op handler reads the value from the
    surrounding graph's params dict, so dropping the getattr node is safe.
    However, when a model directly returns a buffer (e.g. forward returns
    ``self.my_constant``), the getattr appears in the graph outputs. In that
    case, replace the getattr with a constant node holding the buffer value
    so the conversion does not crash.

    Parameters
    ----------
    graph:
        The torch IR graph (or nested block) to transform in place.
    params:
        Mapping from buffer / parameter name to its value. Defaults to
        ``graph.params``; it is threaded through explicitly when recursing
        into control-flow blocks so nested getattr nodes resolve against
        the top-level params dict.
    """
    if params is None:
        params = graph.params

    new_nodes = []
    for node in graph.nodes:
        # Recurse into control-flow blocks first, reusing the top-level
        # params mapping for name lookups inside nested graphs.
        for block in node.blocks:
            remove_getattr_nodes(block, params=params)

        if node.kind == "getattr":
            if node.name in graph.outputs:
                if node.name not in params:
                    raise RuntimeError(
                        "{} appears in the graph outputs but its value was not "
                        "found in the graph params.".format(node.name)
                    )
                # Replace the getattr with a constant node carrying the value,
                # preserving the original output name so downstream consumers
                # and the graph's output list remain valid.
                new_nodes.append(
                    InternalTorchIRNode(
                        kind="constant",
                        inputs=[],
                        outputs=node.outputs,
                        name=node.name,
                        attr={"value": params[node.name]},
                    )
                )
            # An intermediate getattr is simply dropped: its consumers read
            # the value directly from the params dict.
        else:
            new_nodes.append(node)

    graph.nodes = new_nodes
257281

258282

0 commit comments

Comments
 (0)