Skip to content

Commit c5fbab6

Browse files
committed
[Lang] qd.precise: factor disable_fast_math helper, add Vector/select tests
1 parent 97d7fb6 commit c5fbab6

2 files changed

Lines changed: 79 additions & 29 deletions

File tree

quadrants/codegen/llvm/codegen_llvm.cpp

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,27 @@
2222

2323
namespace quadrants::lang {
2424

25+
namespace {
26+
27+
// Clear every fast-math flag on the FP instruction backing `v`, so LLVM cannot reassociate, contract, or
28+
// substitute approximations (e.g. sqrt -> rsqrt+refine, sin -> libm fast variant). No-op if `v` is not an
29+
// FPMathOperator. Note: `setFastMathFlags(FastMathFlags{})` only OR's in flags on this LLVM version, so
30+
// each flag has to be cleared individually.
31+
void disable_fast_math(llvm::Value *v) {
32+
auto *inst = llvm::dyn_cast<llvm::Instruction>(v);
33+
if (!inst || !llvm::isa<llvm::FPMathOperator>(inst))
34+
return;
35+
inst->setHasAllowReassoc(false);
36+
inst->setHasNoNaNs(false);
37+
inst->setHasNoInfs(false);
38+
inst->setHasNoSignedZeros(false);
39+
inst->setHasAllowReciprocal(false);
40+
inst->setHasAllowContract(false);
41+
inst->setHasApproxFunc(false);
42+
}
43+
44+
} // namespace
45+
2546
// TODO: sort function definitions to match declaration order in header
2647

2748
// TODO(k-ye): Hide FunctionCreationGuard inside cpp file
@@ -472,21 +493,8 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
472493
}
473494
#undef UNARY_INTRINSIC
474495

475-
// qd.precise(...) marks this op as IEEE-strict: clear every fast-math flag (inherited from the module-level
476-
// `fast_math` setting via the IRBuilder default) so LLVM cannot substitute approximate variants (e.g.
477-
// sqrt -> rsqrt+refine, sin -> libm fast variant) or otherwise simplify this instruction.
478496
if (stmt->precise) {
479-
if (auto *inst = llvm::dyn_cast<llvm::Instruction>(llvm_val[stmt])) {
480-
if (llvm::isa<llvm::FPMathOperator>(inst)) {
481-
inst->setHasAllowReassoc(false);
482-
inst->setHasNoNaNs(false);
483-
inst->setHasNoInfs(false);
484-
inst->setHasNoSignedZeros(false);
485-
inst->setHasAllowReciprocal(false);
486-
inst->setHasAllowContract(false);
487-
inst->setHasApproxFunc(false);
488-
}
489-
}
497+
disable_fast_math(llvm_val[stmt]);
490498
}
491499
}
492500

@@ -765,22 +773,8 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
765773
}
766774
}
767775

768-
// qd.precise(...) marks this op as IEEE-strict: clear every fast-math flag (inherited from the module-level
769-
// `fast_math` setting via the IRBuilder default) so LLVM can't reassociate, contract, or otherwise simplify
770-
// this instruction. Note: `setFastMathFlags(empty)` only OR's in flags on this LLVM version, so we have to
771-
// clear each individual flag.
772776
if (stmt->precise) {
773-
if (auto *inst = llvm::dyn_cast<llvm::Instruction>(llvm_val[stmt])) {
774-
if (llvm::isa<llvm::FPMathOperator>(inst)) {
775-
inst->setHasAllowReassoc(false);
776-
inst->setHasNoNaNs(false);
777-
inst->setHasNoInfs(false);
778-
inst->setHasNoSignedZeros(false);
779-
inst->setHasAllowReciprocal(false);
780-
inst->setHasAllowContract(false);
781-
inst->setHasApproxFunc(false);
782-
}
783-
}
777+
disable_fast_math(llvm_val[stmt]);
784778
}
785779
}
786780

tests/python/test_precise.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
import numpy as np
11+
import pytest
1112

1213
import quadrants as qd
1314

@@ -169,3 +170,58 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=2)
169170
f"qd.precise(unary) deviated from the correctly-rounded f32 reference by {max_ulp:.2f} ULP. "
170171
f"The unary precise tag is not reaching the codegen for at least one of sin/cos/log/sqrt."
171172
)
173+
174+
175+
@test_utils.test(default_fp=qd.f32)
def test_qd_precise_rejects_quadrants_classes():
    """`qd.precise` only accepts scalar expressions.

    Passing a `Vector` or a `Matrix` must raise a `ValueError` mentioning "Quadrants classes", so a user who
    meant the scalar form sees an explicit error rather than a silent no-op.
    """
    # Operands are built lazily so that construction still happens inside the `pytest.raises` block.
    rejected_operand_builders = (
        lambda: qd.Vector([1.0, 2.0]),
        lambda: qd.Matrix([[1.0, 2.0], [3.0, 4.0]]),
    )
    for build_operand in rejected_operand_builders:
        with pytest.raises(ValueError, match="Quadrants classes"):
            qd.precise(build_operand())
184+
185+
186+
@test_utils.test(default_fp=qd.f32, fast_math=True)
def test_qd_precise_recurses_through_select():
    """`qd.precise` must tag binary ops nested inside a `qd.select` (TernaryOp).

    The signed-zero rule makes this observable. alg_simp folds `x + 0.0 -> x` for any add that is not tagged
    `precise`. If the walker reaches the add inside `qd.precise(qd.select(...))`, the fold is skipped and IEEE
    arithmetic gives `(-0.0) + 0.0 = +0.0`. If the walker does not reach it, the add is stripped and `-0.0`
    survives.
    """

    @qd.kernel
    def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)):
        # `x[0]` is loaded at runtime, so the add never has two compile-time-constant operands and
        # ConstantFold cannot evaluate it ahead of time. alg_simp's `a + 0 -> a` pattern still applies.
        zero = qd.f32(0.0)
        # Unwrapped: alg_simp removes the add and `x[0]` passes through, i.e. bit pattern 0x80000000.
        out[0] = qd.select(qd.i32(1), x[0] + zero, zero)
        # Wrapped: the walker has to go through the select to tag the add; alg_simp then leaves it alone
        # and IEEE `(-0.0) + 0.0` produces `+0.0`, i.e. bit pattern 0x00000000.
        out[1] = qd.precise(qd.select(qd.i32(1), x[0] + zero, zero))

    negative_zero = np.array([-0.0], dtype=np.float32)
    x_in = qd.ndarray(dtype=qd.f32, shape=(1,))
    x_in.from_numpy(negative_zero)
    out = qd.ndarray(dtype=qd.f32, shape=(2,))
    k(x_in, out)

    # Reinterpret the two f32 results as raw bits so that -0.0 and +0.0 can be told apart.
    raw_bits = out.to_numpy().view(np.uint32)
    naive_bits = int(raw_bits[0])
    precise_bits = int(raw_bits[1])

    assert naive_bits == 0x80000000, (
        f"Expected alg_simp to strip the unprotected `-0.0 + 0.0`, leaving bit pattern 0x80000000, "
        f"got 0x{naive_bits:08x}."
    )
    assert precise_bits == 0x00000000, (
        f"Expected `qd.precise(select(..., -0.0 + 0.0, ...))` to recurse through the select, tag the inner "
        f"add, and let IEEE collapse `-0.0 + 0.0` to `+0.0` (bit pattern 0x00000000); got 0x{precise_bits:08x}. "
        f"The walker may not be descending through TernaryOp."
    )
221+
222+
223+
# NOTE: a behavioral test for the `pow` precise-bail (alg_simp.cpp:463) is deliberately omitted. The
224+
# rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to
225+
# the original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable
226+
# difference between `qd.precise(x ** n)` and `x ** n` at runtime today. The gate remains valuable as
227+
# future-proofing (keeps the synthesized mul/div/sqrt chain tagged consistently with what the user wrote).

0 commit comments

Comments
 (0)