Skip to content

Commit 98aba50

Browse files
committed
fix(relax): handle ONNX ScatterElements reduction
1 parent 2c76c79 commit 98aba50

2 files changed

Lines changed: 131 additions & 9 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,20 @@ def _impl_v11(cls, bb, inputs, attr, params):
11591159
raise ValueError("Scatter is deprecated in ONNX 11")
11601160

11611161

1162+
def _get_onnx_reduction(attr, valid_reductions: list[str]):
1163+
reduction = attr.get("reduction", None)
1164+
reduction = reduction or b"update"
1165+
if isinstance(reduction, bytes):
1166+
reduction = reduction.decode("utf-8")
1167+
reduction = "update" if reduction == "none" else reduction
1168+
if reduction not in valid_reductions:
1169+
raise ValueError(
1170+
f"Only {valid_reductions} reductions are supported, but got {reduction}"
1171+
)
1172+
1173+
return reduction
1174+
1175+
11621176
class ScatterElements(OnnxOpConverter):
11631177
"""Convert an onnx ScatterElements node into an equivalent Relax expression."""
11641178

@@ -1167,21 +1181,29 @@ def _impl_v11(cls, bb, inputs, attr, params):
11671181
axis = attr.get("axis", 0)
11681182
return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)
11691183

1184+
@classmethod
1185+
def _impl_v16(cls, bb, inputs, attr, params):
1186+
axis = attr.get("axis", 0)
1187+
reduction = _get_onnx_reduction(attr, ["update", "add", "mul"])
1188+
return relax.op.scatter_elements(
1189+
inputs[0], inputs[1], inputs[2], axis=axis, reduction=reduction
1190+
)
1191+
1192+
@classmethod
1193+
def _impl_v18(cls, bb, inputs, attr, params):
1194+
axis = attr.get("axis", 0)
1195+
reduction = _get_onnx_reduction(attr, ["update", "add", "mul", "min", "max"])
1196+
return relax.op.scatter_elements(
1197+
inputs[0], inputs[1], inputs[2], axis=axis, reduction=reduction
1198+
)
1199+
11701200

11711201
class ScatterND(OnnxOpConverter):
11721202
"""Convert an onnx ScatterND node into an equivalent Relax expression."""
11731203

11741204
@staticmethod
11751205
def _reduction_check(attr, valid_reductions: list[str]):
1176-
reduction = attr.get("reduction", None)
1177-
reduction = reduction or b"update"
1178-
reduction = reduction.decode("utf-8")
1179-
reduction = "update" if reduction == "none" else reduction
1180-
assert reduction in valid_reductions, (
1181-
f"Only {valid_reductions} reductions are supported, but {reduction} is gotten"
1182-
)
1183-
1184-
return reduction
1206+
return _get_onnx_reduction(attr, valid_reductions)
11851207

11861208
@classmethod
11871209
def _impl_v11(cls, bb, inputs, attr, params):

tests/python/relax/test_frontend_onnx.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,106 @@ def test_scatter(axis: int, name: str, opset: int):
10231023
check_correctness(model, inputs={"indices": indices}, opset=opset)
10241024

10251025

1026+
@pytest.mark.parametrize(
1027+
"reduction, opset, data, indices, updates",
1028+
[
1029+
(
1030+
None,
1031+
11,
1032+
np.array([[1, 2, 3], [4, 5, 6]], dtype="float32"),
1033+
np.array([[2, 0, 1], [1, 2, 0]], dtype="int64"),
1034+
np.array([[30, 10, 20], [50, 60, 40]], dtype="float32"),
1035+
),
1036+
(
1037+
"none",
1038+
18,
1039+
np.array([[1, 2, 3], [4, 5, 6]], dtype="float32"),
1040+
np.array([[2, 0, 1], [1, 2, 0]], dtype="int64"),
1041+
np.array([[30, 10, 20], [50, 60, 40]], dtype="float32"),
1042+
),
1043+
(
1044+
"add",
1045+
16,
1046+
np.full((2, 3), 10, dtype="float32"),
1047+
np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
1048+
np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
1049+
),
1050+
(
1051+
"mul",
1052+
16,
1053+
np.full((2, 3), 10, dtype="float32"),
1054+
np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
1055+
np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
1056+
),
1057+
(
1058+
"min",
1059+
18,
1060+
np.full((2, 3), 10, dtype="float32"),
1061+
np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
1062+
np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
1063+
),
1064+
(
1065+
"max",
1066+
18,
1067+
np.full((2, 3), 10, dtype="float32"),
1068+
np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
1069+
np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
1070+
),
1071+
],
1072+
)
1073+
def test_scatter_elements_reduction(reduction, opset, data, indices, updates):
1074+
attrs = {"axis": 1}
1075+
if reduction is not None:
1076+
attrs["reduction"] = reduction
1077+
scatter_elements_node = helper.make_node(
1078+
"ScatterElements", ["data", "indices", "updates"], ["output"], **attrs
1079+
)
1080+
1081+
graph = helper.make_graph(
1082+
[scatter_elements_node],
1083+
"scatter_elements_reduction_test",
1084+
inputs=[
1085+
helper.make_tensor_value_info("data", TensorProto.FLOAT, list(data.shape)),
1086+
helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)),
1087+
helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(updates.shape)),
1088+
],
1089+
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(data.shape))],
1090+
)
1091+
model = helper.make_model(graph, producer_name="scatter_elements_reduction_test")
1092+
1093+
check_correctness(
1094+
model,
1095+
inputs={"data": data, "indices": indices, "updates": updates},
1096+
opset=opset,
1097+
)
1098+
1099+
1100+
def test_scatter_elements_invalid_reduction():
1101+
data_shape = [2, 3]
1102+
scatter_elements_node = helper.make_node(
1103+
"ScatterElements",
1104+
["data", "indices", "updates"],
1105+
["output"],
1106+
axis=1,
1107+
reduction="unsupported",
1108+
)
1109+
1110+
graph = helper.make_graph(
1111+
[scatter_elements_node],
1112+
"scatter_elements_invalid_reduction_test",
1113+
inputs=[
1114+
helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
1115+
helper.make_tensor_value_info("indices", TensorProto.INT64, data_shape),
1116+
helper.make_tensor_value_info("updates", TensorProto.FLOAT, data_shape),
1117+
],
1118+
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)],
1119+
)
1120+
model = helper.make_model(graph, producer_name="scatter_elements_invalid_reduction_test")
1121+
1122+
with pytest.raises(ValueError, match="Only .* reductions are supported, but got unsupported"):
1123+
from_onnx(model, opset=18, keep_params_in_input=True)
1124+
1125+
10261126
@pytest.mark.parametrize("reduction", ["none", "add", "mul"])
10271127
def test_scatter_nd(reduction):
10281128
def verify_scatter_nd(data_shape, indices_shape, updates_shape):

0 commit comments

Comments
 (0)