Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..torchir_passes import (
flatten_graph_input_values,
flatten_graph_output_values,
remove_getattr_nodes,
transform_inplace_ops,
)
import coremltools as ct
Expand Down Expand Up @@ -113,6 +114,63 @@ def test_flatten_input_values():
np.testing.assert_equal(graph.nodes[1].outputs[0], graph.nodes[2].inputs[0])


@staticmethod
def test_remove_getattr_nodes_with_output_buffer():
    # Regression test for #2538: when a model's forward directly returns a
    # buffer, the matching getattr name shows up in graph.outputs. The pass
    # previously raised in that situation; it must instead rewrite each
    # such getattr into a constant node that carries the buffer value.
    params = {
        "buf_a": np.array([1.0, 2.0, 3.0], dtype=np.float32),
        "buf_b": np.array([4.0, 5.0], dtype=np.float32),
    }
    buffer_names = ["buf_a", "buf_b"]
    graph = InternalTorchIRGraph(
        nodes=[
            InternalTorchIRNode(inputs=[], outputs=[name], kind="getattr", name=name)
            for name in buffer_names
        ],
        params=params,
        inputs=OrderedDict(),
        outputs=["buf_a", "buf_b"],
    )

    remove_getattr_nodes(graph)

    # Every getattr was rewritten into a constant node holding the buffer
    # value, and the original output names survived the rewrite.
    np.testing.assert_equal(len(graph.nodes), 2)
    for name, node in zip(buffer_names, graph.nodes):
        np.testing.assert_equal(node.kind, "constant")
        np.testing.assert_array_equal(node.attr["value"], params[name])
        np.testing.assert_equal(node.outputs, [name])


@staticmethod
def test_remove_getattr_nodes_drops_intermediate():
    # Sanity check: a getattr whose name is *not* a graph output must be
    # removed outright — the op that consumes it looks the value up in
    # graph.params instead of reading the getattr's output.
    params = {"weight": np.array([1.0], dtype=np.float32)}
    getattr_node = InternalTorchIRNode(
        inputs=[], outputs=["weight"], kind="getattr", name="weight"
    )
    consumer_node = InternalTorchIRNode(
        inputs=["x", "weight"], outputs=["y"], kind="mul", name="y"
    )
    graph = InternalTorchIRGraph(
        nodes=[getattr_node, consumer_node],
        params=params,
        inputs=OrderedDict([("x", torch.rand(1))]),
        outputs=["y"],
    )

    remove_getattr_nodes(graph)

    # Only the consuming op remains after the pass.
    np.testing.assert_equal(len(graph.nodes), 1)
    np.testing.assert_equal(graph.nodes[0].kind, "mul")


@staticmethod
def test_flatten_output_values():
graph = _build_flattening_test_graph()
Expand Down
46 changes: 35 additions & 11 deletions coremltools/converters/mil/frontend/torch/torchir_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from collections import OrderedDict, defaultdict
from typing import Dict, Optional

import numpy as np

from coremltools import _logger as logger

from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode
Expand Down Expand Up @@ -229,30 +231,52 @@ def forward(self, x):
node.model_hierarchy = cached_model_hierarchy[child_ops[node.name][0]]


def remove_getattr_nodes(
    graph: "InternalTorchIRGraph",
    params: Optional[Dict[str, "np.ndarray"]] = None,
) -> None:
    """
    Remove the getattr nodes from the graph, in place.

    A getattr node typically references a buffer / parameter that is consumed
    by another op; the consuming op handler reads the value from the
    surrounding graph's params dict, so dropping the getattr node is safe.
    However, when a model directly returns a buffer (e.g. forward returns
    ``self.my_constant``), the getattr appears in the graph outputs. In that
    case, replace the getattr with a constant node holding the buffer value
    so the conversion does not crash.

    Parameters
    ----------
    graph:
        The internal torch IR graph (or nested block) to rewrite in place.
    params:
        Name -> value mapping used to materialize constants for getattr
        nodes that appear in the graph outputs. Defaults to
        ``graph.params``. It is threaded through recursive calls so that
        nested blocks resolve names against the top-level params dict.

    Raises
    ------
    RuntimeError
        If a getattr node appears in the graph outputs but its name is
        missing from ``params``.
    """
    if params is None:
        params = graph.params

    new_nodes = []

    for node in graph.nodes:
        # Recurse into sub-blocks first (e.g. loop / conditional bodies),
        # sharing the top-level params for name resolution.
        for block in node.blocks:
            remove_getattr_nodes(block, params=params)

        if node.kind == "getattr":
            if node.name in graph.outputs:
                if node.name not in params:
                    raise RuntimeError(
                        "{} appears in the graph outputs but its value was not "
                        "found in the graph params.".format(node.name)
                    )
                # Replace the getattr with a constant node carrying the value.
                new_nodes.append(
                    InternalTorchIRNode(
                        kind="constant",
                        inputs=[],
                        outputs=node.outputs,
                        name=node.name,
                        attr={"value": params[node.name]},
                    )
                )
            # Otherwise drop the getattr node entirely: the consuming op
            # handler reads the value directly from params.
        else:
            new_nodes.append(node)

    graph.nodes = new_nodes


Expand Down