Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,26 @@
from tvm import TVMError, relax, tir, topi
from tvm.ir import IRModule
from tvm.ir.supply import NameSupply
from tvm.runtime import DataType, DataTypeCode
from tvm.tir.generic import cast
from tvm.topi.utils import get_const_tuple

from ..common import autopad


def _relax_dtype_is_floating_point(dtype: str) -> bool:
    """Whether a Relax dtype string is a floating point type.

    Parameters
    ----------
    dtype : str
        A dtype string such as ``"float32"``, ``"bfloat16"`` or
        ``"float8_e4m3fn"``.

    Returns
    -------
    bool
        True for the standard float, bfloat, and mini-float (float8/6/4)
        type codes; False for integer/bool dtypes or unparsable strings.
    """
    try:
        code = DataType(dtype).type_code
    except (ValueError, TypeError, TVMError):
        # An unrecognized dtype string is simply "not floating point".
        return False
    return (
        code == DataTypeCode.FLOAT
        or code == DataTypeCode.BFLOAT
        # Float8E3M4 .. Float4E2M1FN span the contiguous block of
        # mini-float type codes (mirrors the DLPack DLDataTypeCode order).
        or DataTypeCode.Float8E3M4 <= code <= DataTypeCode.Float4E2M1FN
    )


def get_type(elem_type: str | int) -> str:
"""Converts onnx integer datatype to numpy datatype"""
# If a string was passed instead of a tensor type, it does not need
Expand Down Expand Up @@ -311,6 +325,7 @@ def get_converter(cls, opset):
return getattr(cls, f"_impl_v{version}")
raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented")


class QuantizeLinear(OnnxOpConverter):
@classmethod
def _impl_v10(cls, bb, inputs, attr, params):
Expand Down Expand Up @@ -379,6 +394,7 @@ def _impl_v11(cls, bb, inputs, attr, params):
y = relax.op.quantize(x, y_scale, y_zero_point, axis=0, out_dtype="uint8")
return relax.Tuple([y, y_scale, y_zero_point])


class MatMul(OnnxOpConverter):
"""Converts an onnx MatMul node into an equivalent Relax expression."""

Expand Down Expand Up @@ -1309,6 +1325,15 @@ def _impl_v16(cls, bb, inputs, attr, params):
class Clip(OnnxOpConverter):
"""Converts an onnx Clip node into an equivalent Relax expression."""

@staticmethod
def _sanitize_nan_clip_bound(bb, bound: relax.Expr, *, for_min: bool) -> relax.Expr:
    """Replace a NaN clip bound with +/-inf.

    ONNX Runtime treats a NaN bound as "unbounded on that side"; feeding
    NaN straight into max/min would instead poison the whole output.
    """
    bound_dtype = bound.struct_info.dtype
    if _relax_dtype_is_floating_point(bound_dtype):
        # Substitute the identity element of max (for the min bound) or
        # min (for the max bound) so a NaN bound no longer constrains x.
        neutral = relax.const(-_np.inf if for_min else _np.inf, bound_dtype)
        return bb.emit(relax.op.where(relax.op.isnan(bound), neutral, bound))
    # Non-floating-point bounds cannot be NaN; pass through untouched.
    return bound

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
min = float(attr.get("min", -_np.inf))
Expand All @@ -1325,11 +1350,16 @@ def _impl_v11(cls, bb, inputs, attr, params):

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
    """Clip with tensor min/max inputs (opset 13).

    NaN handling follows ONNX Runtime semantics: a NaN bound is treated
    as unbounded on that side, and NaN elements of the input propagate
    to the output unchanged.
    """
    x: Any = inputs[0]
    results = x
    if inputs[1] is not None:
        # A NaN min bound would poison topi.maximum; sanitize to -inf first.
        lo = cls._sanitize_nan_clip_bound(bb, inputs[1], for_min=True)
        results = bb.emit_te(topi.maximum, results, lo)
    if inputs[2] is not None:
        # A NaN max bound would poison topi.minimum; sanitize to +inf first.
        hi = cls._sanitize_nan_clip_bound(bb, inputs[2], for_min=False)
        results = bb.emit_te(topi.minimum, results, hi)
    if _relax_dtype_is_floating_point(x.struct_info.dtype):
        # Preserve NaN elements of the input: clip(NaN) == NaN per ORT.
        results = bb.emit(relax.op.where(relax.op.isnan(x), x, results))
    return results


Expand Down
49 changes: 49 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,55 @@ def test_clip_v6(max, min):
check_correctness(model, opset=10)


@pytest.mark.parametrize(
    "min_bound,max_bound",
    [
        pytest.param(
            np.array(0.0, dtype=np.float32),
            np.array(6.0, dtype=np.float32),
        ),
        pytest.param(
            np.array(0.0, dtype=np.float32),
            np.array(np.nan, dtype=np.float32),
        ),
        pytest.param(
            np.array(np.nan, dtype=np.float32),
            np.array(6.0, dtype=np.float32),
        ),
        pytest.param(
            np.array(np.nan, dtype=np.float32),
            np.array(np.nan, dtype=np.float32),
        ),
    ],
)
@pytest.mark.parametrize(
    "data",
    [
        np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32),
        np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32),
    ],
)
def test_clip_v13(data, min_bound, max_bound):
    """Opset 13 Clip with tensor min/max.

    A NaN bound means "unbounded on that side" (ORT behavior); NaN
    elements of the input must be preserved in the output.
    """
    clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
    graph = helper.make_graph(
        [clip_node],
        "clip_v13_nan_max",
        inputs=[
            helper.make_tensor_value_info("input", TensorProto.FLOAT, [5]),
            helper.make_tensor_value_info("min", TensorProto.FLOAT, []),
            helper.make_tensor_value_info("max", TensorProto.FLOAT, []),
        ],
        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [5])],
    )
    model = helper.make_model(graph, producer_name="clip_v13_nan_max")
    check_correctness(
        model,
        inputs={"input": data, "min": min_bound, "max": max_bound},
        opset=13,
    )


def test_equal():
equal_node = helper.make_node("Equal", ["a", "b"], ["output"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def __str__(self) -> str:
assert candidates is None


def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace() # pylint: disable = invalid-name
def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace(): # pylint: disable = invalid-name
# Construct an incompatible measured trace: it references block name "other",
# which doesn't exist in Matmul. Replaying this trace should fail and be skipped.
wrong_sch = Schedule(OtherBlock)
Expand Down
Loading