|
8 | 8 | """ |
9 | 9 |
|
10 | 10 | import numpy as np |
| 11 | +import pytest |
11 | 12 |
|
12 | 13 | import quadrants as qd |
13 | 14 |
|
@@ -169,3 +170,58 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=2) |
169 | 170 | f"qd.precise(unary) deviated from the correctly-rounded f32 reference by {max_ulp:.2f} ULP. " |
170 | 171 | f"The unary precise tag is not reaching the codegen for at least one of sin/cos/log/sqrt." |
171 | 172 | ) |
| 173 | + |
| 174 | + |
| 175 | +@test_utils.test(default_fp=qd.f32) |
| 176 | +def test_qd_precise_rejects_quadrants_classes(): |
| 177 | + """`qd.precise` is a scalar primitive. Wrapping a `Vector` or `Matrix` must raise so that users who |
| 178 | + intended the scalar form get a clear error instead of a silent no-op. |
| 179 | + """ |
| 180 | + with pytest.raises(ValueError, match="Quadrants classes"): |
| 181 | + qd.precise(qd.Vector([1.0, 2.0])) |
| 182 | + with pytest.raises(ValueError, match="Quadrants classes"): |
| 183 | + qd.precise(qd.Matrix([[1.0, 2.0], [3.0, 4.0]])) |
| 184 | + |
| 185 | + |
| 186 | +@test_utils.test(default_fp=qd.f32, fast_math=True) |
| 187 | +def test_qd_precise_recurses_through_select(): |
| 188 | + """The walker must descend through `qd.select` (TernaryOp) so inner binary ops get tagged. |
| 189 | +
|
| 190 | + Observable via the signed-zero rule: alg_simp rewrites `x + 0.0 -> x` unconditionally unless the add |
| 191 | + is tagged `precise`. When the add lives inside a `qd.select(...)` wrapped by `qd.precise`, the walker |
| 192 | + must reach it for the rewrite to be skipped -- at which point IEEE arithmetic delivers |
| 193 | + `(-0.0) + 0.0 = +0.0`. Without the tag, alg_simp strips the add and `-0.0` survives. |
| 194 | + """ |
| 195 | + |
| 196 | + @qd.kernel |
| 197 | + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): |
| 198 | + # `x[0]` is a runtime load, so neither operand reduces to a compile-time constant and the |
| 199 | + # ConstantFold pass cannot pre-compute the add. alg_simp's `a + 0 -> a` still matches. |
| 200 | + zero = qd.f32(0.0) |
| 201 | + # Without qd.precise wrap, alg_simp strips the add, leaving `x[0]` itself: bit pattern 0x80000000. |
| 202 | + out[0] = qd.select(qd.i32(1), x[0] + zero, zero) |
| 203 | + # With qd.precise wrap, the walker must recurse through the select and tag the inner add; |
| 204 | + # alg_simp then skips the fold, and IEEE `(-0.0) + 0.0` yields `+0.0`: bit pattern 0x00000000. |
| 205 | + out[1] = qd.precise(qd.select(qd.i32(1), x[0] + zero, zero)) |
| 206 | + |
| 207 | + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) |
| 208 | + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) |
| 209 | + out = qd.ndarray(dtype=qd.f32, shape=(2,)) |
| 210 | + k(x_in, out) |
| 211 | + naive_bits, precise_bits = (int(v.view(np.uint32)) for v in out.to_numpy()) |
| 212 | + assert naive_bits == 0x80000000, ( |
| 213 | + f"Expected alg_simp to strip the unprotected `-0.0 + 0.0`, leaving bit pattern 0x80000000, " |
| 214 | + f"got 0x{naive_bits:08x}." |
| 215 | + ) |
| 216 | + assert precise_bits == 0x00000000, ( |
| 217 | + f"Expected `qd.precise(select(..., -0.0 + 0.0, ...))` to recurse through the select, tag the inner " |
| 218 | + f"add, and let IEEE collapse `-0.0 + 0.0` to `+0.0` (bit pattern 0x00000000); got 0x{precise_bits:08x}. " |
| 219 | + f"The walker may not be descending through TernaryOp." |
| 220 | + ) |
| 221 | + |
| 222 | + |
| 223 | +# NOTE: a behavioral test for the `pow` precise-bail (alg_simp.cpp:463) is deliberately omitted. The |
| 224 | +# rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to |
| 225 | +# the original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable |
| 226 | +# difference between `qd.precise(x ** n)` and `x ** n` at runtime today. The gate remains valuable as |
| 227 | +# future-proofing (keeps the synthesized mul/div/sqrt chain tagged consistently with what the user wrote). |
0 commit comments