Skip to content

Commit 99a5e31

Browse files
hmgaudeckerclaude
andauthored
Pin user-supplied floats to canonical dtype at every API boundary (#345)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent bc4d5b3 commit 99a5e31

34 files changed

Lines changed: 1086 additions & 332 deletions

benchmarks/bench_aca_baseline.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,30 @@
4747

4848

4949
def _build() -> tuple[object, object, object]:
50-
"""Build the aca-baseline model, params, and initial conditions."""
50+
"""Build the aca-baseline model, params, and initial conditions.
51+
52+
aca_model and lcm imports are deferred to the function body — ASV's
53+
forkserver runs `preimport` to discover benchmarks across every
54+
`bench_*.py` module before forking workers. Importing JAX at module
55+
top loads the multithreaded XLA backend into the forkserver; every
56+
subsequent `os.fork()` inherits a corrupted CUDA context and the
57+
first device op in the worker aborts with
58+
`CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the
59+
forkserver and confine it to the worker process.
60+
"""
61+
from aca_model.agent.preferences import BenchmarkPrefType
5162
from aca_model.benchmark import (
5263
create_benchmark_model,
5364
get_benchmark_initial_conditions,
5465
get_benchmark_params,
5566
)
5667

57-
model = create_benchmark_model(n_subjects=_N_SUBJECTS)
68+
from lcm import DiscreteGrid
69+
70+
model = create_benchmark_model(
71+
n_subjects=_N_SUBJECTS,
72+
pref_type_grid=DiscreteGrid(BenchmarkPrefType),
73+
)
5874
_, model_params = get_benchmark_params(model=model)
5975
initial_conditions = get_benchmark_initial_conditions(
6076
model=model, n_subjects=_N_SUBJECTS, seed=0

pixi.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
9898
tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
9999
type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
100100
[tool.pixi.feature.benchmarks.pypi-dependencies]
101-
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" }
101+
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "9ac20430f499a8b1cdb056af85bc2a26e850bad2" }
102102
[tool.pixi.feature.cuda12]
103103
platforms = [ "linux-64" ]
104104
system-requirements = { cuda = "12" }
@@ -242,6 +242,15 @@ per-file-ignores."tests/*" = [
242242
"S301", # Use of pickle
243243
"SLF001", # Private member access
244244
]
245+
per-file-ignores."tests/test_dtypes.py" = [
246+
"ARG001", # Unused function argument (x64_enabled / x64_disabled fixtures)
247+
]
248+
per-file-ignores."tests/test_explicit_dtype_filter.py" = [
249+
"ARG001", # Unused function argument (x64_disabled fixture)
250+
]
251+
per-file-ignores."tests/test_float_dtype_invariants.py" = [
252+
"ARG001", # Unused function argument (x64_disabled fixture)
253+
]
245254
per-file-ignores."tests/test_next_state.py" = [
246255
"ARG001", # Unused function argument
247256
"ARG005", # Unused lambda argument
@@ -294,7 +303,15 @@ ini_options.addopts = [
294303
"--dist",
295304
"loadfile",
296305
]
297-
ini_options.filterwarnings = []
306+
ini_options.filterwarnings = [
307+
# JAX emits this UserWarning when user code asks for a dtype wider
308+
# than the active x64 setting allows. Under `--precision=32` it
309+
# surfaces every stray `jnp.int64` / `jnp.float64` / `dtype=int64`
310+
# literal in src/ — the only files that legitimately trigger it are
311+
# the dtype-invariant test modules, which opt out via a local
312+
# `pytestmark` filter.
313+
"error:Explicitly requested dtype.*:UserWarning",
314+
]
298315
ini_options.markers = [
299316
"illustrative: Tests are designed for illustrative purposes",
300317
"gpu: Tests that require a GPU (skipped on CPU-only machines)",

src/lcm/ages.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import jax.numpy as jnp
1111

1212
from lcm.exceptions import GridInitializationError, format_messages
13-
from lcm.typing import Age, Float1D, Int1D
13+
from lcm.typing import Float1D, Int1D
1414

1515
STEP_UNITS: MappingProxyType[str, Fraction] = MappingProxyType(
1616
{
@@ -129,7 +129,7 @@ def exact_step_size(self) -> int | Fraction | None:
129129
"""
130130
return self._exact_step_size
131131

132-
def period_to_age(self, period: int) -> Age:
132+
def period_to_age(self, period: int) -> int | float:
133133
"""Convert a period index to the corresponding age.
134134
135135
Args:
@@ -151,7 +151,7 @@ def period_to_age(self, period: int) -> Age:
151151
return int(self._values[period])
152152
return float(self._values[period])
153153

154-
def age_to_period(self, age: Age) -> int:
154+
def age_to_period(self, age: float) -> int:
155155
"""Convert an age to the corresponding period index.
156156
157157
Args:
@@ -172,12 +172,14 @@ def age_to_period(self, age: Age) -> int:
172172
raise ValueError(msg) from None
173173

174174
@functools.cached_property
175-
def _age_to_period_map(self) -> dict[Age, int]:
175+
def _age_to_period_map(self) -> dict[int | float, int]:
176176
if self._is_integer:
177177
return {int(v): i for i, v in enumerate(self._exact_values)}
178178
return {float(v): i for i, v in enumerate(self._exact_values)}
179179

180-
def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]:
180+
def get_periods_where(
181+
self, predicate: Callable[[int | float], bool]
182+
) -> tuple[int, ...]:
181183
"""Get period indices where predicate is True.
182184
183185
Args:
@@ -187,7 +189,7 @@ def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]
187189
Tuple of period indices where predicate(age) is True.
188190
189191
"""
190-
_convert: Callable[[object], Age] = int if self._is_integer else float # ty: ignore[invalid-assignment]
192+
_convert: Callable[[object], int | float] = int if self._is_integer else float # ty: ignore[invalid-assignment]
191193
return tuple(
192194
period
193195
for period in range(self.n_periods)

src/lcm/dtypes.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,32 @@
33
Used at every API boundary that accepts user data (params, initial
44
conditions, regime-id arrays) — always called from Python, never inside
55
JIT. Each helper validates that the value fits the target dtype and
6-
raises a clearly-named error if not.
7-
8-
Casts further down the simulate stack (e.g. transition outputs landing
9-
in the state pool) use plain `.astype` and rely on the boundary cast
10-
above them having already pinned the canonical dtype.
6+
raises a clearly-named error if not. Once an input has crossed the
7+
boundary it carries the canonical dtype unchanged through the simulate
8+
stack; downstream code does not re-cast.
119
"""
1210

11+
import jax
1312
import jax.numpy as jnp
1413
import numpy as np
1514
from jax import Array
1615

1716
_INT32_MIN = int(np.iinfo(np.int32).min)
1817
_INT32_MAX = int(np.iinfo(np.int32).max)
18+
_FLOAT32_MAX = float(np.finfo(np.float32).max)
19+
20+
21+
def canonical_float_dtype() -> jnp.dtype:
22+
"""Return pylcm's canonical float dtype, derived from `jax_enable_x64`.
23+
24+
Returns `jnp.float64` if `jax.config.jax_enable_x64` is True,
25+
otherwise `jnp.float32`. The value is read at call time, not at
26+
import, so toggling the JAX config (e.g. between tests) is honoured.
27+
"""
28+
return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
1929

2030

21-
def safe_to_int32(value: object, *, name: str) -> Array:
31+
def safe_to_int_dtype(value: object, *, name: str) -> Array:
2232
"""Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range.
2333
2434
Args:
@@ -46,3 +56,41 @@ def safe_to_int32(value: object, *, name: str) -> Array:
4656
)
4757
raise ValueError(msg)
4858
return jnp.asarray(np_value, dtype=jnp.int32)
59+
60+
61+
def safe_to_float_dtype(value: object, *, name: str) -> Array:
62+
"""Cast a scalar, sequence, or array to the canonical float dtype.
63+
64+
Range check fires only on a down-cast:
65+
66+
- Down-cast (float64 → float32 under `jax_enable_x64=False`): raise
67+
`OverflowError` if any element exceeds float32 magnitude rather
68+
than letting JAX silently saturate to ``±inf``.
69+
- Up-cast or same-width cast: skip the range check. Precision loss
70+
within range is not an error — it is an inherent consequence of
71+
`jax_enable_x64=False`.
72+
73+
Args:
74+
value: A Python float, numpy/JAX scalar, or array-like.
75+
name: Qualified name of the leaf — surfaced in the error message.
76+
77+
Returns:
78+
A JAX array at `canonical_float_dtype()` (0-d if `value` was a
79+
scalar).
80+
81+
Raises:
82+
OverflowError: If down-casting to `float32` would saturate any
83+
element to `±inf`. The message names the leaf via `name`.
84+
85+
"""
86+
target_dtype = canonical_float_dtype()
87+
np_value = np.asarray(value)
88+
if target_dtype == jnp.float32 and np_value.size > 0:
89+
max_mag = float(np.max(np.abs(np_value)))
90+
if max_mag > _FLOAT32_MAX:
91+
msg = (
92+
f"{name}: float32 overflow — max |value| {max_mag:g} "
93+
f"exceeds float32 max {_FLOAT32_MAX:g}."
94+
)
95+
raise OverflowError(msg)
96+
return jnp.asarray(np_value, dtype=target_dtype)

0 commit comments

Comments
 (0)