Skip to content

Commit a44208d

Browse files
committed
[Lang] qd.precise: propagate tag in 2*a rewrite, narrow zero-fold gate, refresh test_api
1 parent c5fbab6 commit a44208d

2 files changed

Lines changed: 11 additions & 4 deletions

File tree

quadrants/transforms/alg_simp.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,9 @@ class AlgSimp : public BasicStmtVisitor {
377377
auto sum = Stmt::make<BinaryOpStmt>(BinaryOpType::add, a, a);
378378
sum->ret_type = a->ret_type;
379379
sum->dbg_info = stmt->dbg_info;
380+
// `2 * a` and `a + a` are IEEE-equivalent, but the synthesized add must carry `precise` so the
381+
// downstream FMF clear / NoContraction plumbing still sees the user's opt-in tag.
382+
static_cast<BinaryOpStmt *>(sum.get())->precise = stmt->precise;
380383
stmt->replace_usages_with(sum.get());
381384
modifier.insert_before(stmt, std::move(sum));
382385
modifier.erase(stmt);
@@ -442,12 +445,15 @@ class AlgSimp : public BasicStmtVisitor {
442445
optimize_division(stmt);
443446
} else if (stmt->op_type == BinaryOpType::add || stmt->op_type == BinaryOpType::sub ||
444447
stmt->op_type == BinaryOpType::bit_or || stmt->op_type == BinaryOpType::bit_xor) {
445-
if (alg_is_zero(rhs) && !stmt->precise) {
446-
// a +-|^ 0 -> a. Skipped when `stmt->precise` is set: `(-0.0) + 0.0` yields `+0.0` under IEEE.
448+
const bool precise_fp_add = stmt->precise && stmt->op_type == BinaryOpType::add;
449+
if (alg_is_zero(rhs) && !precise_fp_add) {
450+
// a +-|^ 0 -> a. Skipped only for `precise` FP adds: `(-0.0) + 0.0` yields `+0.0` under IEEE.
451+
// `a - 0 -> a` is IEEE-exact for every `a` and `bit_or`/`bit_xor` are integer ops, so they
452+
// stay unconditional.
447453
stmt->replace_usages_with(stmt->lhs);
448454
modifier.erase(stmt);
449-
} else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs) && !stmt->precise) {
450-
// 0 +|^ a -> a. Skipped when `stmt->precise` is set (same signed-zero reasoning).
455+
} else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs) && !precise_fp_add) {
456+
// 0 +|^ a -> a. Same reasoning.
451457
stmt->replace_usages_with(stmt->rhs);
452458
modifier.erase(stmt);
453459
} else if (stmt->op_type == BinaryOpType::bit_or && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {

tests/python/test_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def _get_expected_matrix_apis():
188188
"perf_dispatch",
189189
"polar_decompose",
190190
"pow",
191+
"precise",
191192
"profiler",
192193
"pure",
193194
"pyfunc",

0 commit comments

Comments
 (0)