|
3 | 3 | Used at every API boundary that accepts user data (params, initial |
4 | 4 | conditions, regime-id arrays) — always called from Python, never inside |
5 | 5 | JIT. Each helper validates that the value fits the target dtype and |
6 | | -raises a clearly-named error if not. |
7 | | -
|
8 | | -Casts further down the simulate stack (e.g. transition outputs landing |
9 | | -in the state pool) use plain `.astype` and rely on the boundary cast |
10 | | -above them having already pinned the canonical dtype. |
| 6 | +raises a clearly-named error if not. Once an input has crossed the |
| 7 | +boundary it carries the canonical dtype unchanged through the simulate |
| 8 | +stack; downstream code does not re-cast. |
11 | 9 | """ |
12 | 10 |
|
| 11 | +import jax |
13 | 12 | import jax.numpy as jnp |
14 | 13 | import numpy as np |
15 | 14 | from jax import Array |
16 | 15 |
|
17 | 16 | _INT32_MIN = int(np.iinfo(np.int32).min) |
18 | 17 | _INT32_MAX = int(np.iinfo(np.int32).max) |
| 18 | +_FLOAT32_MAX = float(np.finfo(np.float32).max) |
| 19 | + |
| 20 | + |
| 21 | +def canonical_float_dtype() -> jnp.dtype: |
| 22 | + """Return pylcm's canonical float dtype, derived from `jax_enable_x64`. |
| 23 | +
|
| 24 | + Returns `jnp.float64` if `jax.config.jax_enable_x64` is True, |
| 25 | + otherwise `jnp.float32`. The value is read at call time, not at |
| 26 | + import, so toggling the JAX config (e.g. between tests) is honoured. |
| 27 | + """ |
| 28 | + return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32 |
19 | 29 |
|
20 | 30 |
|
21 | | -def safe_to_int32(value: object, *, name: str) -> Array: |
| 31 | +def safe_to_int_dtype(value: object, *, name: str) -> Array: |
22 | 32 | """Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range. |
23 | 33 |
|
24 | 34 | Args: |
@@ -46,3 +56,41 @@ def safe_to_int32(value: object, *, name: str) -> Array: |
46 | 56 | ) |
47 | 57 | raise ValueError(msg) |
48 | 58 | return jnp.asarray(np_value, dtype=jnp.int32) |
| 59 | + |
| 60 | + |
| 61 | +def safe_to_float_dtype(value: object, *, name: str) -> Array: |
| 62 | + """Cast a scalar, sequence, or array to the canonical float dtype. |
| 63 | +
|
| 64 | + Range check fires only on a down-cast: |
| 65 | +
|
| 66 | + - Down-cast (float64 → float32 under `jax_enable_x64=False`): raise |
| 67 | + `OverflowError` if any element exceeds float32 magnitude rather |
| 68 | + than letting JAX silently saturate to ``±inf``. |
| 69 | + - Up-cast or same-width cast: skip the range check. Precision loss |
| 70 | + within range is not an error — it is an inherent consequence of |
| 71 | + `jax_enable_x64=False`. |
| 72 | +
|
| 73 | + Args: |
| 74 | + value: A Python float, numpy/JAX scalar, or array-like. |
| 75 | + name: Qualified name of the leaf — surfaced in the error message. |
| 76 | +
|
| 77 | + Returns: |
| 78 | + A JAX array at `canonical_float_dtype()` (0-d if `value` was a |
| 79 | + scalar). |
| 80 | +
|
| 81 | + Raises: |
| 82 | + OverflowError: If down-casting to `float32` would saturate any |
| 83 | + element to `±inf`. The message names the leaf via `name`. |
| 84 | +
|
| 85 | + """ |
| 86 | + target_dtype = canonical_float_dtype() |
| 87 | + np_value = np.asarray(value) |
| 88 | + if target_dtype == jnp.float32 and np_value.size > 0: |
| 89 | + max_mag = float(np.max(np.abs(np_value))) |
| 90 | + if max_mag > _FLOAT32_MAX: |
| 91 | + msg = ( |
| 92 | + f"{name}: float32 overflow — max |value| {max_mag:g} " |
| 93 | + f"exceeds float32 max {_FLOAT32_MAX:g}." |
| 94 | + ) |
| 95 | + raise OverflowError(msg) |
| 96 | + return jnp.asarray(np_value, dtype=target_dtype) |
0 commit comments