Expected behavior
TVM Relax should execute ONNX Sign consistently with ONNX Runtime when the input contains NaN, so NaN elements should come out as NaN.
For this input:
x = [NaN, 9.0, -9.0, NaN]
ONNX Runtime returns:
[nan, 1.0, -1.0, nan]
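As an independent reference point (NumPy is used here only for comparison and is not part of either runtime), `np.sign` also propagates NaN and gives the same result as ONNX Runtime for this input:

```python
import numpy as np

# NumPy's sign propagates NaN, matching the ONNX Runtime output above.
x = np.array([np.nan, 9.0, -9.0, np.nan], dtype=np.float32)
expected = np.sign(x)
print(expected.tolist())  # [nan, 1.0, -1.0, nan]
```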
Actual behavior
TVM Relax returns 0.0 for NaN inputs, while the non-NaN elements match:
ORT: [nan, 1.0, -1.0, nan]
TVM: [0.0, 1.0, -1.0, 0.0]
The discrepancy appears when importing an ONNX Sign model through the Relax ONNX frontend and compiling it for the llvm target.
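The TVM output is exactly what a Sign computed from nested comparisons against zero would produce, because every ordered comparison with NaN is false and NaN falls through to the 0 branch. This is a hypothesis for illustration only, not a reading of TVM's actual Sign lowering; `comparison_sign` and `nan_preserving_sign` are hypothetical NumPy helpers. The second helper sketches one way to keep the same formula while passing NaN through:

```python
import numpy as np

def comparison_sign(x):
    # Hypothetical select-based formula: 1 if x > 0, -1 if x < 0, else 0.
    # Both comparisons are False for NaN, so NaN ends up as 0.
    one = np.ones_like(x)
    zero = np.zeros_like(x)
    return np.where(x > 0, one, np.where(x < 0, -one, zero))

def nan_preserving_sign(x):
    # Same formula, but NaN inputs are returned unchanged.
    return np.where(np.isnan(x), x, comparison_sign(x))

x = np.array([np.nan, 9.0, -9.0, np.nan], dtype=np.float32)
print(comparison_sign(x).tolist())      # [0.0, 1.0, -1.0, 0.0]
print(nan_preserving_sign(x).tolist())  # [nan, 1.0, -1.0, nan]
```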
Environment
TVM: 0.14 (Relax ONNX frontend)
ONNX Runtime: 1.23
Python: 3.11
Target: llvm
OS: Linux
Steps to reproduce
```python
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper

import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

# Minimal ONNX model: a single Sign node over a float32 vector of length 4.
node = helper.make_node("Sign", ["x"], ["y"])
graph = helper.make_graph(
    [node],
    "g",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
model.ir_version = 9

# Input with NaN in the first and last positions.
x = np.array([np.nan, 9.0, -9.0, np.nan], dtype=np.float32)

# Reference result from ONNX Runtime on CPU.
ort_out = ort.InferenceSession(
    model.SerializeToString(),
    providers=["CPUExecutionProvider"],
).run(None, {"x": x})[0]

# Import through the Relax ONNX frontend and compile for llvm.
mod = from_onnx(model, keep_params_in_input=False)
with tvm.transform.PassContext(opt_level=3):
    ex = tvm.compile(mod, target=tvm.target.Target("llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](tvm.runtime.tensor(x, tvm.cpu()))
tvm_out = (out[0] if isinstance(out, (list, tuple)) else out).numpy()

print("ORT:", ort_out.tolist())
print("TVM:", tvm_out.tolist())
```
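The script above only prints both outputs. To turn the mismatch into a pass/fail check, a NaN-aware comparison could be appended; `np.testing.assert_allclose` treats NaNs in the same positions as equal by default (`equal_nan=True`). In this sketch the arrays are hard-coded from the values reported in this issue so it runs on its own; in the reproduction script, `ort_out` and `tvm_out` would come from the runtimes instead. `outputs_match` and `mismatched_indices` are hypothetical helper names.

```python
import numpy as np

# Values reported in this issue (stand-ins for the script's ort_out / tvm_out).
ort_out = np.array([np.nan, 1.0, -1.0, np.nan], dtype=np.float32)
tvm_out = np.array([0.0, 1.0, -1.0, 0.0], dtype=np.float32)

def outputs_match(expected, actual):
    # assert_allclose counts NaN vs NaN in the same position as equal by default.
    try:
        np.testing.assert_allclose(actual, expected)
        return True
    except AssertionError:
        return False

def mismatched_indices(expected, actual):
    # Positions where the outputs differ, treating NaN == NaN as a match.
    return np.flatnonzero(~np.isclose(actual, expected, equal_nan=True)).tolist()

print(outputs_match(ort_out, ort_out))       # True
print(outputs_match(ort_out, tvm_out))       # False
print(mismatched_indices(ort_out, tvm_out))  # [0, 3]
```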
Triage