Skip to content

Commit 7d7d2a0

Browse files
hmgaudeckerclaude
andcommitted
pytest: promote 'Explicitly requested dtype' UserWarning to an error
JAX silently truncates `jnp.int64` / `jnp.float64` requests under `jax_enable_x64=False` and emits a `UserWarning`. The default test config (`filterwarnings = []`) let those warnings pass — a stray `int64` literal in src/ would slip through CI as a warning the operator would have to spot by eye. Switch the filter to `error:Explicitly requested dtype.*:UserWarning`. Combined with the existing `--precision=32` job (`tests-32bit`), every wide-dtype literal in src/ now fails the suite. The three dtype-invariant test modules (`test_int_dtype_invariants`, `test_float_dtype_invariants`, `test_dtypes`) opt back to the warning default via a module-level `pytestmark` — they exist to *exercise* the cast at the barrier and legitimately pass `int64` / `float64` inputs. Add `tests/test_explicit_dtype_filter.py` with two tests confirming the filter is in effect: each requests a wide dtype and asserts the warning surfaces as `UserWarning`. Addresses the review on #340 without the false-positive surface of a literal-string grep. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7547ac3 commit 7d7d2a0

5 files changed

Lines changed: 59 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ per-file-ignores."tests/*" = [
245245
per-file-ignores."tests/test_dtypes.py" = [
246246
"ARG001", # Unused function argument (x64_enabled / x64_disabled fixtures)
247247
]
248+
per-file-ignores."tests/test_explicit_dtype_filter.py" = [
249+
"ARG001", # Unused function argument (x64_disabled fixture)
250+
]
248251
per-file-ignores."tests/test_float_dtype_invariants.py" = [
249252
"ARG001", # Unused function argument (x64_disabled fixture)
250253
]
@@ -300,7 +303,15 @@ ini_options.addopts = [
300303
"--dist",
301304
"loadfile",
302305
]
303-
ini_options.filterwarnings = []
306+
ini_options.filterwarnings = [
307+
# JAX emits this UserWarning when user code asks for a dtype wider
308+
# than the active x64 setting allows. Under `--precision=32` it
309+
# surfaces every stray `jnp.int64` / `jnp.float64` / `dtype=int64`
310+
# literal in src/ — the only files that legitimately trigger it are
311+
# the dtype-invariant test modules, which opt out via a local
312+
# `pytestmark` filter.
313+
"error:Explicitly requested dtype.*:UserWarning",
314+
]
304315
ini_options.markers = [
305316
"illustrative: Tests are designed for illustrative purposes",
306317
"gpu: Tests that require a GPU (skipped on CPU-only machines)",

tests/test_dtypes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66

77
from lcm.dtypes import canonical_float_dtype, safe_to_float_dtype, safe_to_int_dtype
88

9+
# Several tests here pass `int64` / `float64` inputs to verify the
10+
# barrier helpers cast them correctly. Re-allow the JAX truncation
11+
# warning that the project-wide filter (see `pyproject.toml`) promotes
12+
# to an error — the legitimate trigger lives here.
13+
pytestmark = pytest.mark.filterwarnings(
14+
"default:Explicitly requested dtype.*:UserWarning"
15+
)
16+
917

1018
@pytest.mark.parametrize(
1119
"value",
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Pyproject's `filterwarnings` rule promotes the JAX truncation warning to an error.
2+
3+
The rule fires when source code asks for a dtype wider than the active
4+
`jax_enable_x64` setting permits — under `--precision=32`, every stray
5+
`jnp.int64` / `jnp.float64` request in `src/` becomes a test failure.
6+
The dtype-invariant test modules opt back to `default` for the same
7+
warning because they exist to *exercise* the cast at the barrier.
8+
"""
9+
10+
import jax.numpy as jnp
11+
import pytest
12+
13+
14+
def test_float64_request_under_no_x64_raises(x64_disabled: None):
15+
"""A `jnp.float64` literal under `jax_enable_x64=False` is promoted to an error."""
16+
with pytest.raises(UserWarning, match="Explicitly requested dtype float64"):
17+
jnp.asarray([1.0, 2.0], dtype=jnp.float64)
18+
19+
20+
def test_int64_request_under_no_x64_raises(x64_disabled: None):
21+
"""A `jnp.int64` literal under `jax_enable_x64=False` is promoted to an error."""
22+
with pytest.raises(UserWarning, match="Explicitly requested dtype int64"):
23+
jnp.asarray([1, 2, 3], dtype=jnp.int64)

tests/test_float_dtype_invariants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
get_params,
2020
)
2121

22+
# These tests deliberately pass `float64` inputs to verify the cast at
23+
# the barrier. Re-allow the JAX truncation warning that the
24+
# project-wide filter (see `pyproject.toml`) promotes to an error —
25+
# the legitimate trigger lives here.
26+
pytestmark = pytest.mark.filterwarnings(
27+
"default:Explicitly requested dtype.*:UserWarning"
28+
)
29+
2230

2331
def test_build_initial_states_casts_user_float64_to_canonical(x64_disabled: None):
2432
"""A float64 continuous initial state lands at `canonical_float_dtype()`."""

tests/test_int_dtype_invariants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@
2525
working_life,
2626
)
2727

28+
# These tests deliberately pass `int64` inputs to verify the cast at
29+
# the barrier. Re-allow the JAX truncation warning that the
30+
# project-wide filter (see `pyproject.toml`) promotes to an error —
31+
# the legitimate trigger lives here.
32+
pytestmark = pytest.mark.filterwarnings(
33+
"default:Explicitly requested dtype.*:UserWarning"
34+
)
35+
2836

2937
def test_discrete_grid_to_jax_is_int32() -> None:
3038
"""Every `DiscreteGrid.to_jax()` in the model returns an `int32` array."""

0 commit comments

Comments
 (0)