Skip to content

Commit 8e95dd9

Browse files
hmgaudeckertimmens
andauthored
Stricter param validation: reject numpy arrays, validate Series labels (#302)
Co-authored-by: timmens <mensingertim@gmail.com>
1 parent 4a512e8 commit 8e95dd9

6 files changed

Lines changed: 180 additions & 8 deletions

File tree

AGENTS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ Model(
220220
\- Simulate forward given solution. `period_to_regime_to_V_arr` is optional; when
221221
`None`, the model is solved automatically before simulating.
222222

223+
### Derived Categoricals
224+
225+
When `solve()` / `simulate()` parameters are indexed by a DAG function output (not a
226+
model state/action), pass `derived_categoricals={"name": DiscreteGrid(...)}`. Functions
227+
used as derived categoricals must return **integer** types, not booleans — JAX cannot
228+
use booleans as array indices inside JIT. Use `jnp.int32(...)` to cast.
229+
223230
### SimulationResult
224231

225232
`simulate()` returns a `SimulationResult` object:

docs/user_guide/pandas_interop.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,23 @@ derived_categoricals = {
143143
}
144144
```
145145

146+
### Integer return types required
147+
148+
Functions used as derived categoricals must return **integer** values, not booleans. JAX
149+
cannot use boolean values as array indices inside JIT-compiled code
150+
(`NonConcreteBooleanIndexError`). If your derived categorical compares states:
151+
152+
```python
153+
# Wrong — returns bool, fails inside JIT
154+
def is_good_health(health: DiscreteState) -> BoolND:
155+
return health == Health.good
156+
157+
158+
# Correct — returns int32
159+
def is_good_health(health: DiscreteState) -> IntND:
160+
return jnp.int32(health == Health.good)
161+
```
162+
146163
## Validating Transition Probabilities
147164

148165
Check that a transition probability array has the correct shape, values in $[0, 1]$, and

src/lcm/model.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,22 @@
44
from pathlib import Path
55
from types import MappingProxyType
66

7+
import pandas as pd
78
from jax import Array
89

910
from lcm.ages import AgeGrid
11+
from lcm.exceptions import InvalidParamsError
1012
from lcm.grids import DiscreteGrid
1113
from lcm.model_processing import (
1214
build_regimes_and_template,
1315
validate_model_inputs,
1416
)
17+
from lcm.pandas_utils import (
18+
convert_series_in_params,
19+
has_series,
20+
initial_conditions_from_dataframe,
21+
)
22+
from lcm.params import MappingLeaf, SequenceLeaf
1523
from lcm.params.processing import (
1624
process_params,
1725
)
@@ -195,6 +203,7 @@ def solve(
195203
internal_params = _maybe_convert_series(
196204
internal_params, model=self, derived_categoricals=derived_categoricals
197205
)
206+
_validate_param_types(internal_params)
198207
validate_regime_transitions_all_periods(
199208
internal_regimes=self.internal_regimes,
200209
internal_params=internal_params,
@@ -283,6 +292,7 @@ def simulate(
283292
internal_params = _maybe_convert_series(
284293
internal_params, model=self, derived_categoricals=derived_categoricals
285294
)
295+
_validate_param_types(internal_params)
286296
if check_initial_conditions:
287297
validate_initial_conditions(
288298
initial_conditions=initial_conditions,
@@ -336,8 +346,6 @@ def _maybe_convert_series(
336346
| None,
337347
) -> InternalParams:
338348
"""Convert pd.Series leaves in params to JAX arrays if any are present."""
339-
from lcm.pandas_utils import convert_series_in_params, has_series # noqa: PLC0415
340-
341349
if derived_categoricals is not None or has_series(internal_params):
342350
return convert_series_in_params(
343351
internal_params=internal_params,
@@ -347,17 +355,51 @@ def _maybe_convert_series(
347355
return internal_params
348356

349357

358+
def _validate_param_types(internal_params: InternalParams) -> None:
359+
"""Raise if any param leaf is not a Python scalar or JAX array.
360+
361+
After processing, every leaf value (including inside MappingLeaf /
362+
SequenceLeaf containers) must be a Python scalar (float, int, bool) or a
363+
JAX array. Notably, numpy arrays and pandas Series are not accepted.
364+
"""
365+
for regime_name, regime_params in internal_params.items():
366+
for key, value in regime_params.items():
367+
_check_leaf(value, f"{regime_name}__{key}")
368+
369+
370+
def _check_leaf(value: object, path: str) -> None:
371+
"""Check a single leaf value, recursing into MappingLeaf/SequenceLeaf."""
372+
if isinstance(value, MappingLeaf):
373+
for k, v in value.data.items():
374+
_check_leaf(v, f"{path}.{k}")
375+
return
376+
if isinstance(value, SequenceLeaf):
377+
for i, v in enumerate(value.data):
378+
_check_leaf(v, f"{path}[{i}]")
379+
return
380+
if isinstance(value, (float, int, bool)):
381+
return
382+
if hasattr(value, "dtype") and hasattr(value, "shape"):
383+
if isinstance(value, Array):
384+
return
385+
type_name = type(value).__module__ + "." + type(value).__name__
386+
msg = (
387+
f"Parameter '{path}' is a {type_name} (shape {value.shape}). "
388+
f"Use jax.numpy.array() or pass a pd.Series with a named index."
389+
)
390+
raise InvalidParamsError(msg)
391+
type_name = type(value).__module__ + "." + type(value).__name__
392+
msg = f"Parameter '{path}' has unexpected type {type_name}."
393+
raise InvalidParamsError(msg)
394+
395+
350396
def _maybe_convert_dataframe(
351397
initial_conditions: Mapping[str, Array],
352398
*,
353399
model: Model,
354400
) -> Mapping[str, Array]:
355401
"""Convert a DataFrame to initial_conditions dict if needed."""
356-
import pandas as pd # noqa: PLC0415
357-
358402
if isinstance(initial_conditions, pd.DataFrame):
359-
from lcm.pandas_utils import initial_conditions_from_dataframe # noqa: PLC0415
360-
361403
return initial_conditions_from_dataframe(df=initial_conditions, model=model)
362404
return initial_conditions
363405

src/lcm/pandas_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Callable, Mapping
44
from dataclasses import dataclass
55
from types import MappingProxyType
6-
from typing import cast
6+
from typing import TYPE_CHECKING, cast
77

88
import jax.numpy as jnp
99
import numpy as np
@@ -13,8 +13,11 @@
1313

1414
from lcm.ages import AgeGrid
1515
from lcm.grids import DiscreteGrid, IrregSpacedGrid
16-
from lcm.model import Model
1716
from lcm.params import MappingLeaf
17+
18+
if TYPE_CHECKING:
19+
from lcm.model import Model # avoid circular import: pandas_utils ↔ model
20+
1821
from lcm.params.sequence_leaf import SequenceLeaf
1922
from lcm.regime import Regime
2023
from lcm.shocks import _ShockGrid
@@ -726,6 +729,20 @@ def _map_level(*, mapping: _LevelMapping, level_values: pd.Index) -> np.ndarray:
726729
ValueError: If any label is not valid for the mapping.
727730
728731
"""
732+
# Categorical levels must use string labels matching grid category names.
733+
# Reject integer labels early with a clear message instead of a cryptic KeyError.
734+
if mapping.valid_labels and any(not isinstance(v, str) for v in level_values):
735+
non_str_types = sorted(
736+
{type(v).__name__ for v in level_values if not isinstance(v, str)}
737+
)
738+
msg = (
739+
f"Series index level '{mapping.name}' uses non-string labels "
740+
f"(types: {non_str_types}) but the DiscreteGrid expects string "
741+
f"category names. Use string labels matching: "
742+
f"{sorted(mapping.valid_labels)}."
743+
)
744+
raise ValueError(msg)
745+
729746
try:
730747
return np.array([mapping.label_to_index(v) for v in level_values])
731748
except ValueError:

tests/test_pandas_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,27 @@ def test_array_from_series_wrong_level_names_raises() -> None:
10971097
)
10981098

10991099

1100+
def test_array_from_series_integer_labels_rejected() -> None:
1101+
"""Integer labels on a categorical level raise ValueError."""
1102+
model = get_stochastic_model(3)
1103+
# Use integer labels (0, 1) instead of string category names
1104+
index = pd.MultiIndex.from_tuples(
1105+
[(40.0, 0, "single", "single"), (40.0, 1, "single", "single")],
1106+
names=["age", "labor_supply", "partner", "next_partner"],
1107+
)
1108+
series = pd.Series([0.5, 0.5], index=index)
1109+
func = model.regimes["working_life"].get_all_functions()["next_partner"]
1110+
with pytest.raises(ValueError, match="non-string labels"):
1111+
array_from_series(
1112+
sr=series,
1113+
func=func,
1114+
param_name="probs_array",
1115+
func_name="next_partner",
1116+
model=model,
1117+
regime_name="working_life",
1118+
)
1119+
1120+
11001121
def test_array_from_series_scalar_param_explicit_lookup() -> None:
11011122
"""Scalar parameter with explicit func lookup returns 1D array."""
11021123
model = get_stochastic_model(3)

tests/test_validate_param_types.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Test that numpy arrays in params are rejected after processing."""
2+
3+
import jax.numpy as jnp
4+
import numpy as np
5+
import pytest
6+
7+
from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical
8+
from lcm.exceptions import InvalidParamsError
9+
10+
11+
@categorical(ordered=True)
12+
class Health:
13+
bad: int
14+
good: int
15+
16+
17+
@categorical(ordered=False)
18+
class RegimeId:
19+
working: int
20+
dead: int
21+
22+
23+
def _next_regime() -> int:
24+
return RegimeId.dead
25+
26+
27+
working = Regime(
28+
transition=_next_regime,
29+
active=lambda age: age < 30,
30+
states={
31+
"health": DiscreteGrid(Health),
32+
"wealth": LinSpacedGrid(start=0, stop=100, n_points=5),
33+
},
34+
state_transitions={"health": None, "wealth": lambda wealth: wealth},
35+
functions={"utility": lambda wealth, health, bonus: wealth + health + bonus},
36+
)
37+
38+
dead = Regime(
39+
transition=None,
40+
functions={"utility": lambda: 0.0},
41+
)
42+
43+
44+
def _make_model() -> Model:
45+
return Model(
46+
regimes={"working": working, "dead": dead},
47+
ages=AgeGrid(start=25, stop=30, step="Y"),
48+
regime_id_class=RegimeId,
49+
)
50+
51+
52+
def test_numpy_array_param_rejected() -> None:
53+
"""Passing a numpy array as a param should raise InvalidParamsError."""
54+
model = _make_model()
55+
with pytest.raises(InvalidParamsError, match=r"numpy\.ndarray"):
56+
model.solve(params={"bonus": np.array(1.0), "discount_factor": 0.95}) # ty: ignore[invalid-argument-type]
57+
58+
59+
def test_jax_array_param_accepted() -> None:
60+
"""JAX arrays should be accepted."""
61+
model = _make_model()
62+
model.solve(params={"bonus": jnp.array(1.0), "discount_factor": 0.95})
63+
64+
65+
def test_python_scalar_param_accepted() -> None:
66+
"""Python scalars should be accepted."""
67+
model = _make_model()
68+
model.solve(params={"bonus": 1.0, "discount_factor": 0.95})

0 commit comments

Comments
 (0)