Skip to content

Commit 2e6ee08

Browse files
swjngclaude
andauthored
[BugFix] Align tir.round to ties-to-even across all backends (#19368)
## Problem `tir.round` constant-folds using `std::nearbyint` (IEEE 754 ties-to-even), but all backends lower it to platform `round()` which uses ties-away-from-zero. This means compiled code can produce different results from constant-folded code for midpoint values: | Input | Constant-fold (ties-to-even) | Compiled (ties-away) | |-------|-----|------| | 0.5 | 0.0 | 1.0 | | 2.5 | 2.0 | 3.0 | | -0.5 | 0.0 | -1.0 | This was identified as a follow-up to #19367 — see [this comment](#19367 (comment)). ## Fix Align all backends to use ties-to-even intrinsics, matching the constant-folding behavior: | Backend | Before | After | |---------|--------|-------| | LLVM/ROCm/Hexagon | `llvm::Intrinsic::round` | `llvm::Intrinsic::nearbyint` | | NVPTX | `__nv_round[f]` | `__nv_nearbyint[f]` | | CUDA | `round`/`roundf` | `nearbyint`/`nearbyintf` (f16/bf16 already used `hrint`) | | Metal/OpenCL | `round` | `rint` | | Vulkan/SPIR-V | `GLSLstd450Round` | `GLSLstd450RoundEven` | Also fixes OpenCL codegen where `tir.nearbyint` was incorrectly mapped to OpenCL `round()` instead of `rint()`. Updates `op.h` documentation to explicitly state ties-to-even semantics for both `round()` and `nearbyint()`. ## Testing ``` python -m pytest tests/python/tirx-base/test_tir_intrin.py -xvs ``` New `test_round_ties_to_even` verifies midpoint inputs `[0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5]` produce ties-to-even results on the LLVM backend. All 12 tests pass (10 passed, 2 skipped for CUDA). --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2345e6e commit 2e6ee08

13 files changed

Lines changed: 87 additions & 20 deletions

File tree

include/tvm/tirx/op.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -654,19 +654,25 @@ TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());
654654
TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
655655

656656
/*!
657-
* \brief Calculate round(x)
657+
* \brief Round x to the nearest integer, ties to even.
658+
*
659+
* Uses IEEE 754 default rounding mode (ties-to-even / banker's rounding).
660+
* Constant-folding and all backends consistently use std::nearbyint semantics.
661+
*
658662
* \param x The input expression.
659663
* \param span The location of this operation in the source.
660664
* \return The result expression.
661665
*/
662666
TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
663667

664668
/*!
665-
* \brief Calculates std::nearbyint(x)
669+
* \brief Round x to the nearest integer, ties to even.
670+
*
671+
* Equivalent to round(). Both use IEEE 754 default rounding mode (ties-to-even).
672+
*
666673
* \param x The input expression.
667674
* \param span The location of this operation in the source.
668675
* \return The result expression.
669-
* This is a faster alternate to round.
670676
*/
671677
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());
672678

python/tvm/topi/testing/roi_pool_python.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ def roi_pool_nchw_python(a_np, rois_np, pooled_size, spatial_scale):
3636
for i in range(num_roi):
3737
roi = rois_np[i]
3838
batch_index = int(roi[0])
39-
roi_start_w = round(roi[1] * spatial_scale)
40-
roi_start_h = round(roi[2] * spatial_scale)
41-
roi_end_w = round(roi[3] * spatial_scale)
42-
roi_end_h = round(roi[4] * spatial_scale)
39+
# Use ties-away-from-zero rounding to match ONNX runtime (std::round semantics).
40+
# Python's built-in round() uses ties-to-even, so use floor(x + 0.5) explicitly.
41+
roi_start_w = math.floor(roi[1] * spatial_scale + 0.5)
42+
roi_start_h = math.floor(roi[2] * spatial_scale + 0.5)
43+
roi_end_w = math.floor(roi[3] * spatial_scale + 0.5)
44+
roi_end_h = math.floor(roi[4] * spatial_scale + 0.5)
4345
roi_h = max(roi_end_h - roi_start_h + 1, 1)
4446
roi_w = max(roi_end_w - roi_start_w + 1, 1)
4547

python/tvm/topi/vision/roi_pool.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,19 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
3636

3737
neg_inf = tvm.tirx.const(float("-inf"), data.dtype)
3838

39+
def _round_away(x):
40+
# ONNX MaxRoiPool spec uses ties-away-from-zero rounding for coordinate
41+
# mapping (matching std::round semantics in the reference implementation).
42+
# Use floor(x + 0.5) to be explicit and independent of tir.round semantics.
43+
half = tvm.tirx.const(0.5, roi_dtype)
44+
return te.floor(x + half)
45+
3946
def _bin_bounds(i, ph, pw):
4047
roi = rois[i]
41-
roi_start_w = te.round(roi[1] * spatial_scale).astype("int32")
42-
roi_start_h = te.round(roi[2] * spatial_scale).astype("int32")
43-
roi_end_w = te.round(roi[3] * spatial_scale).astype("int32")
44-
roi_end_h = te.round(roi[4] * spatial_scale).astype("int32")
48+
roi_start_w = _round_away(roi[1] * spatial_scale).astype("int32")
49+
roi_start_h = _round_away(roi[2] * spatial_scale).astype("int32")
50+
roi_end_w = _round_away(roi[3] * spatial_scale).astype("int32")
51+
roi_end_h = _round_away(roi[4] * spatial_scale).astype("int32")
4552

4653
roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32"))
4754
roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32"))

src/target/llvm/intrin_rule_hexagon.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ TVM_REGISTER_OP("tirx.fabs")
9393

9494
TVM_REGISTER_OP("tirx.round")
9595
.set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
96-
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
96+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
9797

9898
TVM_REGISTER_OP("tirx.ctpop")
9999
.set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ TVM_REGISTER_OP("tirx.fabs")
9090

9191
TVM_REGISTER_OP("tirx.round")
9292
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
93-
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
93+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
9494

9595
TVM_REGISTER_OP("tirx.nearbyint")
9696
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",

src/target/llvm/intrin_rule_nvptx.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,15 @@ TVM_REGISTER_OP("tirx.ceil")
6666
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
6767

6868
TVM_REGISTER_OP("tirx.round")
69-
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
69+
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
70+
// Redirect to nearbyint (ties-to-even) to match constant-folding semantics.
71+
using namespace tirx;
72+
const CallNode* call = e.as<CallNode>();
73+
TVM_FFI_ICHECK(call != nullptr);
74+
auto nearbyint_op = Op::Get("tirx.nearbyint");
75+
auto new_call = Call(call->dtype, nearbyint_op, call->args);
76+
return DispatchPureExternLibDevice(new_call);
77+
});
7078

7179
TVM_REGISTER_OP("tirx.nearbyint")
7280
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);

src/target/llvm/intrin_rule_rocm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ TVM_REGISTER_OP("tirx.ceil")
132132

133133
TVM_REGISTER_OP("tirx.round")
134134
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
135-
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
135+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
136136

137137
TVM_REGISTER_OP("tirx.nearbyint")
138138
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",

src/target/source/codegen_opencl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) {
526526
this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "atomic_add_float_emu", op->args,
527527
true, os);
528528
} else if (func->value == "nearbyint") {
529-
this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "round", op->args, true, os);
529+
this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "rint", op->args, true, os);
530530
} else {
531531
if (func->value == "atomic_add") {
532532
enable_atomics_ = true;

src/target/source/intrin_rule_cuda.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@ struct CUDAMath {
3737
if (t.is_float()) {
3838
switch (t.bits()) {
3939
case 64:
40+
// Use nearbyint (ties-to-even) for round to match constant-folding semantics.
41+
if (name == "round") return "nearbyint";
4042
return name;
4143
case 32:
44+
if (name == "round") return "nearbyintf";
4245
return name + 'f';
4346
case 16: {
4447
if (name == "fabs") {

src/target/source/intrin_rule_metal.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,16 @@ TVM_REGISTER_OP("tirx.fabs")
6868
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
6969

7070
TVM_REGISTER_OP("tirx.round")
71-
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
71+
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
72+
// Metal's rint() uses ties-to-even, matching constant-folding semantics.
73+
const tirx::CallNode* call = e.as<tirx::CallNode>();
74+
TVM_FFI_ICHECK(call != nullptr);
75+
ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
76+
for (auto arg : call->args) {
77+
new_args.push_back(arg);
78+
}
79+
return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args);
80+
});
7281

7382
TVM_REGISTER_OP("tirx.nearbyint")
7483
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

0 commit comments

Comments
 (0)