diff --git a/coremltools/converters/mil/frontend/torch/test/test_passes.py b/coremltools/converters/mil/frontend/torch/test/test_passes.py
index 4401ebbfb..7ffe2ae55 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_passes.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_passes.py
@@ -17,6 +17,7 @@ from ..torchir_passes import (
     flatten_graph_input_values,
     flatten_graph_output_values,
+    remove_getattr_nodes,
     transform_inplace_ops,
 )
 
 import coremltools as ct
@@ -113,6 +114,63 @@ def test_flatten_input_values():
     np.testing.assert_equal(graph.nodes[1].outputs[0], graph.nodes[2].inputs[0])
 
+    @staticmethod
+    def test_remove_getattr_nodes_with_output_buffer():
+        # Regression test for #2538: when a model directly returns a buffer,
+        # the corresponding getattr appears in graph.outputs and the pass
+        # used to raise. It now replaces the getattr with a constant node
+        # holding the buffer value so conversion can proceed.
+        params = {
+            "buf_a": np.array([1.0, 2.0, 3.0], dtype=np.float32),
+            "buf_b": np.array([4.0, 5.0], dtype=np.float32),
+        }
+        graph_nodes = [
+            InternalTorchIRNode(inputs=[], outputs=["buf_a"], kind="getattr", name="buf_a"),
+            InternalTorchIRNode(inputs=[], outputs=["buf_b"], kind="getattr", name="buf_b"),
+        ]
+        graph = InternalTorchIRGraph(
+            nodes=graph_nodes,
+            params=params,
+            inputs=OrderedDict(),
+            outputs=["buf_a", "buf_b"],
+        )
+
+        remove_getattr_nodes(graph)
+
+        np.testing.assert_equal(len(graph.nodes), 2)
+        for node in graph.nodes:
+            np.testing.assert_equal(node.kind, "constant")
+        np.testing.assert_array_equal(graph.nodes[0].attr["value"], params["buf_a"])
+        np.testing.assert_array_equal(graph.nodes[1].attr["value"], params["buf_b"])
+        # Original output names are preserved.
+        np.testing.assert_equal(graph.nodes[0].outputs, ["buf_a"])
+        np.testing.assert_equal(graph.nodes[1].outputs, ["buf_b"])
+
+
+    @staticmethod
+    def test_remove_getattr_nodes_drops_intermediate():
+        # Sanity check: a getattr node that is *not* in graph.outputs should
+        # still be dropped (the consuming op handler reads from graph.params).
+        params = {"weight": np.array([1.0], dtype=np.float32)}
+        graph_nodes = [
+            InternalTorchIRNode(inputs=[], outputs=["weight"], kind="getattr", name="weight"),
+            InternalTorchIRNode(
+                inputs=["x", "weight"], outputs=["y"], kind="mul", name="y"
+            ),
+        ]
+        graph = InternalTorchIRGraph(
+            nodes=graph_nodes,
+            params=params,
+            inputs=OrderedDict([("x", torch.rand(1))]),
+            outputs=["y"],
+        )
+
+        remove_getattr_nodes(graph)
+
+        np.testing.assert_equal(len(graph.nodes), 1)
+        np.testing.assert_equal(graph.nodes[0].kind, "mul")
+
+
     @staticmethod
     def test_flatten_output_values():
         graph = _build_flattening_test_graph()
diff --git a/coremltools/converters/mil/frontend/torch/torchir_passes.py b/coremltools/converters/mil/frontend/torch/torchir_passes.py
index cf784cdf0..18d5ab2b6 100644
--- a/coremltools/converters/mil/frontend/torch/torchir_passes.py
+++ b/coremltools/converters/mil/frontend/torch/torchir_passes.py
@@ -6,6 +6,8 @@
 from collections import OrderedDict, defaultdict
 from typing import Dict, Optional
 
+import numpy as np
+
 from coremltools import _logger as logger
 
 from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode
@@ -229,30 +231,52 @@ def forward(self, x):
         node.model_hierarchy = cached_model_hierarchy[child_ops[node.name][0]]
 
 
-def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None:
+def remove_getattr_nodes(
+    graph: InternalTorchIRGraph,
+    params: Optional[Dict[str, "np.ndarray"]] = None,
+) -> None:
     """
-    Remove the getattr nodes in the graph
+    Remove the getattr nodes from the graph.
+
+    A getattr node typically references a buffer / parameter that is consumed
+    by another op; the consuming op handler reads the value from the
+    surrounding graph's params dict, so dropping the getattr node is safe.
+    However, when a model directly returns a buffer (e.g. forward returns
+    `self.my_constant`), the getattr appears in the graph outputs. In that
+    case, replace the getattr with a constant node holding the buffer value
+    so the conversion does not crash.
     """
-    getattr_nodes = []
+    if params is None:
+        params = graph.params
     new_nodes = []
     for node in graph.nodes:
         for block in node.blocks:
-            remove_getattr_nodes(block)
+            remove_getattr_nodes(block, params=params)
         if node.kind == "getattr":
-            getattr_nodes.append(node)
+            if node.name in graph.outputs:
+                if node.name not in params:
+                    raise RuntimeError(
+                        "{} appears in the graph outputs but its value was not "
+                        "found in the graph params.".format(node.name)
+                    )
+                # Replace the getattr with a constant node carrying the value.
+                new_nodes.append(
+                    InternalTorchIRNode(
+                        kind="constant",
+                        inputs=[],
+                        outputs=node.outputs,
+                        name=node.name,
+                        attr={"value": params[node.name]},
+                    )
+                )
         else:
             new_nodes.append(node)
-    # check the getattr nodes not in the outputs
-    for node in getattr_nodes:
-        if node.name in graph.outputs:
-            raise RuntimeError("{} should not be in the graph outputs.".format(node.name))
-
-    # remove the getattr nodes
     graph.nodes = new_nodes