Skip to content

Commit 74ed86b

Browse files
[Relax][Frontend][Onnx] Add support for pad-2 (#17431)
* fix params name bug * add support for onnx pad_v2 * Update test_frontend_onnx.py * Update onnx_frontend.py
1 parent 910ee0e commit 74ed86b

2 files changed

Lines changed: 86 additions & 0 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,35 @@ def _impl_v13(cls, bb, inputs, attr, params):
15821582
class Pad(OnnxOpConverter):
15831583
"""Converts an onnx Pad node into an equivalent Relax expression."""
15841584

1585+
@classmethod
1586+
def _impl_v2(cls, bb, inputs, attr, params):
1587+
pads = attr.get("pads")
1588+
pads = relax.const(_np.array(pads), inputs[0].struct_info.shape[0].dtype)
1589+
constant_value = attr.get("value")
1590+
if constant_value is None:
1591+
constant_value = 0.0
1592+
1593+
if isinstance(pads, relax.Constant):
1594+
pad_before, pad_after = _np.split(pads.data.numpy(), 2)
1595+
pad_before = _np.ndarray.tolist(pad_before)
1596+
pad_after = _np.ndarray.tolist(pad_after)
1597+
else:
1598+
raise ValueError("Dynamic pads are not supported yet.")
1599+
1600+
pad_mode = attr.get("mode", b"constant").decode("utf-8")
1601+
if not pad_mode in ["constant", "edge", "reflect"]:
1602+
raise tvm.error.OpAttributeInvalid(
1603+
"Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.'
1604+
)
1605+
1606+
if pad_mode == "constant":
1607+
return bb.emit_te(topi.nn.pad, inputs[0], pad_before, pad_after, constant_value)
1608+
elif pad_mode == "reflect":
1609+
return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT")
1610+
else:
1611+
# TODO(gigiblender) Support edge mode.
1612+
raise NotImplementedError("Pad mode {} not implemented".format(pad_mode))
1613+
15851614
@classmethod
15861615
def _impl_v11(cls, bb, inputs, attr, params):
15871616
pads = get_constant(inputs[1], params)

tests/python/relax/test_frontend_onnx.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,63 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0):
16961696
verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect")
16971697

16981698

1699+
@pytest.mark.parametrize("dynamic", [True, False])
1700+
def test_pad_v2(dynamic):
1701+
1702+
if dynamic:
1703+
pytest.skip("Dynamic pad not supported")
1704+
1705+
def verify_pad(input_shape, pads, mode="constant", value=0.0):
1706+
indata = np.random.normal(size=input_shape).astype(np.float32)
1707+
# numpy expect result
1708+
len_dim = len(pads) // 2
1709+
np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)]
1710+
pads = np.array(pads)
1711+
# onnx graph
1712+
if mode in ["edge", "reflect"]:
1713+
outdata = np.pad(indata, pad_width=np_pads, mode=mode)
1714+
node = helper.make_node(
1715+
"Pad", inputs=["input"], outputs=["output"], mode=mode, pads=pads
1716+
)
1717+
graph = helper.make_graph(
1718+
[node],
1719+
"pad_test",
1720+
inputs=[
1721+
helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))
1722+
],
1723+
outputs=[
1724+
helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))
1725+
],
1726+
)
1727+
else:
1728+
outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value)
1729+
node = helper.make_node(
1730+
"Pad",
1731+
inputs=["input"],
1732+
outputs=["output"],
1733+
mode="constant",
1734+
pads=pads,
1735+
value=value,
1736+
)
1737+
graph = helper.make_graph(
1738+
[node],
1739+
"pad_test",
1740+
inputs=[
1741+
helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))
1742+
],
1743+
outputs=[
1744+
helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))
1745+
],
1746+
)
1747+
model = helper.make_model(graph, producer_name="pad_test")
1748+
check_correctness(model=model, opset=10)
1749+
1750+
verify_pad((2, 2), [0, 1, 0, 0], "constant", 0.0)
1751+
verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0)
1752+
verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0)
1753+
verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect")
1754+
1755+
16991756
@pytest.mark.parametrize("fp_arith", [np.float16, np.float32])
17001757
@pytest.mark.parametrize("dynamic", [True, False])
17011758
def test_split(fp_arith, dynamic):

0 commit comments

Comments
 (0)