Skip to content

Commit b1918c7

Browse files
[Fix][Relax]: ONNX Clip NaN bounds and preserve input NaN (ORT parity) (#19535)
This PR fixes #19533: - Sanitize floating tensor min/max: replace a NaN min with -inf and a NaN max with +inf before topi max/min, so the bounds match ONNX "unbounded" semantics where a NaN bound imposes no constraint. - After clamping, preserve NaNs from the input tensor on floating dtypes. - Extend check_correctness with equal_nan for float outputs containing NaN. - Add parametrized Clip opset-13 tests for NaN min/max tensor bounds.
1 parent 378c4f3 commit b1918c7

3 files changed

Lines changed: 83 additions & 4 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,26 @@
5252
from tvm import TVMError, relax, tirx, topi
5353
from tvm.ir import IRModule
5454
from tvm.ir.supply import NameSupply
55+
from tvm.runtime import DataType, DataTypeCode
5556
from tvm.tirx.generic import cast
5657
from tvm.topi.utils import get_const_tuple
5758

5859
from ..common import autopad
5960

6061

62+
def _relax_dtype_is_floating_point(dtype: str) -> bool:
    """Whether a Relax dtype string is a floating point type."""
    try:
        code = DataType(dtype).type_code
    except (ValueError, TypeError, TVMError):
        # Unparseable dtype strings are simply treated as non-floating.
        return False
    if code in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT):
        return True
    # Minifloat (float8/float6/float4) type codes occupy a contiguous range.
    return DataTypeCode.Float8E3M4 <= code <= DataTypeCode.Float4E2M1FN
73+
74+
6175
def get_type(elem_type: str | int) -> str:
6276
"""Converts onnx integer datatype to numpy datatype"""
6377
# If a string was passed instead of a tensor type, it does not need
@@ -311,6 +325,7 @@ def get_converter(cls, opset):
311325
return getattr(cls, f"_impl_v{version}")
312326
raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented")
313327

328+
314329
class QuantizeLinear(OnnxOpConverter):
315330
@classmethod
316331
def _impl_v10(cls, bb, inputs, attr, params):
@@ -379,6 +394,7 @@ def _impl_v11(cls, bb, inputs, attr, params):
379394
y = relax.op.quantize(x, y_scale, y_zero_point, axis=0, out_dtype="uint8")
380395
return relax.Tuple([y, y_scale, y_zero_point])
381396

397+
382398
class MatMul(OnnxOpConverter):
383399
"""Converts an onnx MatMul node into an equivalent Relax expression."""
384400

@@ -1350,6 +1366,15 @@ def _impl_v16(cls, bb, inputs, attr, params):
13501366
class Clip(OnnxOpConverter):
13511367
"""Converts an onnx Clip node into an equivalent Relax expression."""
13521368

1369+
@staticmethod
1370+
def _sanitize_nan_clip_bound(bb, bound: relax.Expr, *, for_min: bool) -> relax.Expr:
1371+
"""ONNX/ORT treat NaN clip bounds as unbounded; plain max/min with NaN poisons output."""
1372+
dtype = bound.struct_info.dtype
1373+
if not _relax_dtype_is_floating_point(dtype):
1374+
return bound
1375+
repl = -_np.inf if for_min else _np.inf
1376+
return bb.emit(relax.op.where(relax.op.isnan(bound), relax.const(repl, dtype), bound))
1377+
13531378
@classmethod
13541379
def _impl_v1(cls, bb, inputs, attr, params):
13551380
min = float(attr.get("min", -_np.inf))
@@ -1366,11 +1391,16 @@ def _impl_v11(cls, bb, inputs, attr, params):
13661391

13671392
    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        # Opset 13 Clip: inputs[0] is the data tensor; inputs[1]/inputs[2] are
        # the optional tensor-valued min/max bounds (None when omitted).
        x: Any = inputs[0]
        results = x
        if inputs[1] is not None:
            # NaN lower bound is treated as unbounded (-inf) to match ORT.
            lo = cls._sanitize_nan_clip_bound(bb, inputs[1], for_min=True)
            results = bb.emit_te(topi.maximum, results, lo)
        if inputs[2] is not None:
            # NaN upper bound is treated as unbounded (+inf) to match ORT.
            hi = cls._sanitize_nan_clip_bound(bb, inputs[2], for_min=False)
            results = bb.emit_te(topi.minimum, results, hi)
        if _relax_dtype_is_floating_point(x.struct_info.dtype):
            # ONNX Clip propagates NaN from the input: restore any input NaNs
            # that the max/min clamping above may have replaced.
            results = bb.emit(relax.op.where(relax.op.isnan(x), x, results))
        return results
13751405

13761406

tests/python/relax/test_frontend_onnx.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,55 @@ def test_clip_v6(max, min):
15971597
check_correctness(model, opset=10)
15981598

15991599

1600+
@pytest.mark.parametrize(
    "min,max",
    [
        pytest.param(
            np.array(0.0, dtype=np.float32),
            np.array(6.0, dtype=np.float32),
        ),
        pytest.param(
            np.array(0.0, dtype=np.float32),
            np.array(np.nan, dtype=np.float32),
        ),
        pytest.param(
            np.array(np.nan, dtype=np.float32),
            np.array(6.0, dtype=np.float32),
        ),
        pytest.param(
            np.array(np.nan, dtype=np.float32),
            np.array(np.nan, dtype=np.float32),
        ),
    ],
)
@pytest.mark.parametrize(
    "input",
    [
        np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32),
        np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32),
    ],
)
def test_clip_v13(input, min, max):
    """Clip opset 13 with scalar-tensor min/max bounds, including NaN cases.

    A NaN bound must act as unbounded on that side (ORT parity), and NaN
    values present in the input tensor must be preserved in the output.
    """
    # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT); input NaN preserved.
    clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
    graph = helper.make_graph(
        [clip_node],
        "clip_v13_nan_max",
        inputs=[
            helper.make_tensor_value_info("input", TensorProto.FLOAT, [5]),
            # min/max are rank-0 (scalar) tensors per the opset-13 signature.
            helper.make_tensor_value_info("min", TensorProto.FLOAT, []),
            helper.make_tensor_value_info("max", TensorProto.FLOAT, []),
        ],
        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [5])],
    )
    model = helper.make_model(graph, producer_name="clip_v13_nan_max")
    # check_correctness compares against ORT output (with equal_nan for floats).
    check_correctness(
        model,
        inputs={"input": input, "min": min, "max": max},
        opset=13,
    )
1647+
1648+
16001649
def test_equal():
16011650
equal_node = helper.make_node("Equal", ["a", "b"], ["output"])
16021651

tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def __str__(self) -> str:
324324
assert candidates is None
325325

326326

327-
def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace() # pylint: disable = invalid-name
327+
def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace(): # pylint: disable = invalid-name
328328
# Construct an incompatible measured trace: it references block name "other",
329329
# which doesn't exist in Matmul. Replaying this trace should fail and be skipped.
330330
wrong_sch = Schedule(OtherBlock)

0 commit comments

Comments
 (0)