Skip to content

Commit 303df59

Browse files
hmgaudeckerclaude
andauthored
Auto-convert pd.Series in fixed_params (#308)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d5fe32c commit 303df59

9 files changed

Lines changed: 918 additions & 361 deletions

File tree

AGENTS.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,12 @@ Model(
222222

223223
### Derived Categoricals
224224

225-
When `solve()` / `simulate()` parameters are indexed by a DAG function output (not a
226-
model state/action), pass `derived_categoricals={"name": DiscreteGrid(...)}`. Functions
227-
used as derived categoricals must return **integer** types, not booleans — JAX cannot
228-
use booleans as array indices inside JIT. Use `jnp.int32(...)` to cast.
225+
When parameters are indexed by a DAG function output (not a model state/action), declare
226+
`derived_categoricals={"name": DiscreteGrid(CategoryClass)}` on the `Regime` that uses
227+
it. For convenience, model-level `derived_categoricals` on `Model(...)` are broadcast to
228+
all regimes. Functions used as derived categoricals must return **integer** types, not
229+
booleans — JAX cannot use booleans as array indices inside JIT. Use `jnp.int32(...)` to
230+
cast.
229231

230232
### SimulationResult
231233

docs/user_guide/pandas_interop.md

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,29 +118,41 @@ validate labels against. You will see an error like:
118118
```
119119
Unrecognised indexing parameter 'employment_type'. Expected 'age' or a
120120
discrete grid name (['health', 'partner']). If 'employment_type' is a DAG
121-
function output, pass derived_categoricals={"employment_type": DiscreteGrid(...)}
122-
to solve() / simulate().
121+
function output, add derived_categoricals={"employment_type": DiscreteGrid(EmploymentType)}
122+
to the Regime or Model constructor.
123123
```
124124

125-
Fix this by passing the missing grid explicitly:
125+
Fix this by declaring the grid on the `Regime` that uses it:
126126

127127
```python
128-
model.solve(
129-
params=params,
128+
working = Regime(
129+
# ... other fields ...
130130
derived_categoricals={"employment_type": DiscreteGrid(EmploymentType)},
131131
)
132132
```
133133

134-
If the variable has different categories in different regimes, pass a per-regime
135-
mapping:
134+
If the variable has different categories in different regimes, each regime declares its
135+
own grid:
136136

137137
```python
138-
derived_categoricals = {
139-
"employment_type": {
140-
"working": DiscreteGrid(FullEmploymentType),
141-
"retired": DiscreteGrid(RetiredEmploymentType),
142-
},
143-
}
138+
working = Regime(
139+
# ... other fields ...
140+
derived_categoricals={"employment_type": DiscreteGrid(FullEmploymentType)},
141+
)
142+
retired = Regime(
143+
# ... other fields ...
144+
derived_categoricals={"employment_type": DiscreteGrid(RetiredEmploymentType)},
145+
)
146+
```
147+
148+
For convenience, model-level `derived_categoricals` are broadcast to all regimes:
149+
150+
```python
151+
Model(
152+
regimes={"working": working, "retired": retired},
153+
derived_categoricals={"employment_type": DiscreteGrid(EmploymentType)},
154+
# ... other fields ...
155+
)
144156
```
145157

146158
### Integer return types required

src/lcm/model.py

Lines changed: 68 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Collection of classes that are used by the user to define the model and grids."""
22

3+
import dataclasses
34
from collections.abc import Mapping
45
from pathlib import Path
56
from types import MappingProxyType
@@ -8,9 +9,10 @@
89
from jax import Array
910

1011
from lcm.ages import AgeGrid
11-
from lcm.exceptions import InvalidParamsError
12+
from lcm.exceptions import ModelInitializationError
1213
from lcm.grids import DiscreteGrid
1314
from lcm.model_processing import (
15+
_validate_param_types,
1416
build_regimes_and_template,
1517
validate_model_inputs,
1618
)
@@ -19,7 +21,6 @@
1921
has_series,
2022
initial_conditions_from_dataframe,
2123
)
22-
from lcm.params import MappingLeaf, SequenceLeaf
2324
from lcm.params.processing import (
2425
process_params,
2526
)
@@ -78,7 +79,7 @@ class Model:
7879
"""Immutable mapping of regime names to internal regime instances."""
7980

8081
enable_jit: bool = True
81-
"""Whether to JIT-compile the functions of the internal regime."""
82+
"""Whether to JIT-compile the functions of the internal regimes."""
8283

8384
fixed_params: UserParams
8485
"""Parameters fixed at model initialization."""
@@ -95,6 +96,7 @@ def __init__(
9596
regime_id_class: type,
9697
enable_jit: bool = True,
9798
fixed_params: UserParams = MappingProxyType({}),
99+
derived_categoricals: Mapping[str, DiscreteGrid] = MappingProxyType({}),
98100
) -> None:
99101
"""Initialize the Model.
100102
@@ -103,8 +105,13 @@ def __init__(
103105
ages: Age grid for the model.
104106
description: Description of the model.
105107
regime_id_class: Dataclass mapping regime names to integer indices.
106-
enable_jit: Whether to jit the functions of the internal regime.
108+
enable_jit: Whether to JIT-compile the functions of the internal
109+
regimes.
107110
fixed_params: Parameters that can be fixed at model initialization.
111+
derived_categoricals: Categorical grids for DAG function outputs
112+
not in states/actions. Broadcast to all regimes (merged with
113+
each regime's own `derived_categoricals`). Raises if a regime
114+
already has a conflicting entry.
108115
109116
"""
110117
self.description = description
@@ -125,10 +132,10 @@ def __init__(
125132
)
126133
)
127134
)
128-
self.regimes = MappingProxyType(dict(regimes))
135+
self.regimes = _merge_derived_categoricals(regimes, derived_categoricals)
129136
self.internal_regimes, self._params_template = build_regimes_and_template(
130-
regimes=regimes,
131137
ages=self.ages,
138+
regimes=self.regimes,
132139
regime_names_to_ids=self.regime_names_to_ids,
133140
enable_jit=enable_jit,
134141
fixed_params=self.fixed_params,
@@ -162,8 +169,6 @@ def solve(
162169
self,
163170
*,
164171
params: UserParams,
165-
derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]]
166-
| None = None,
167172
log_level: LogLevel = "progress",
168173
log_path: str | Path | None = None,
169174
log_keep_n_latest: int = 3,
@@ -181,10 +186,6 @@ def solve(
181186
specification
182187
Values may be `pd.Series` with labeled indices; they are
183188
auto-converted to JAX arrays.
184-
derived_categoricals: Extra categorical mappings (level name to
185-
`DiscreteGrid`) for derived variables not in the model's
186-
state/action grids. Pass per-regime mappings as
187-
`{"var": {"regime_a": grid_a, ...}}`.
188189
log_level: Logging verbosity. `"off"` suppresses output, `"warning"` shows
189190
NaN/Inf warnings, `"progress"` adds timing, `"debug"` adds stats and
190191
requires `log_path`.
@@ -197,13 +198,7 @@ def solve(
197198
198199
"""
199200
_validate_log_args(log_level=log_level, log_path=log_path)
200-
internal_params = process_params(
201-
params=params, params_template=self._params_template
202-
)
203-
internal_params = _maybe_convert_series(
204-
internal_params, model=self, derived_categoricals=derived_categoricals
205-
)
206-
_validate_param_types(internal_params)
201+
internal_params = self._process_params(params)
207202
validate_regime_transitions_all_periods(
208203
internal_regimes=self.internal_regimes,
209204
internal_params=internal_params,
@@ -229,8 +224,6 @@ def simulate(
229224
self,
230225
*,
231226
params: UserParams,
232-
derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]]
233-
| None = None,
234227
initial_conditions: Mapping[str, Array],
235228
period_to_regime_to_V_arr: MappingProxyType[
236229
int, MappingProxyType[RegimeName, FloatND]
@@ -259,10 +252,6 @@ def simulate(
259252
specification
260253
Values may be `pd.Series` with labeled indices; they are
261254
auto-converted to JAX arrays.
262-
derived_categoricals: Extra categorical mappings (level name to
263-
`DiscreteGrid`) for derived variables not in the model's
264-
state/action grids. Pass per-regime mappings as
265-
`{"var": {"regime_a": grid_a, ...}}`.
266255
initial_conditions: Mapping of state names (plus `"regime"`) to arrays.
267256
All arrays must have the same length (number of subjects). The
268257
`"regime"` entry must contain integer regime codes (from
@@ -285,14 +274,13 @@ def simulate(
285274
286275
"""
287276
_validate_log_args(log_level=log_level, log_path=log_path)
288-
initial_conditions = _maybe_convert_dataframe(initial_conditions, model=self)
289-
internal_params = process_params(
290-
params=params, params_template=self._params_template
291-
)
292-
internal_params = _maybe_convert_series(
293-
internal_params, model=self, derived_categoricals=derived_categoricals
294-
)
295-
_validate_param_types(internal_params)
277+
if isinstance(initial_conditions, pd.DataFrame):
278+
initial_conditions = initial_conditions_from_dataframe(
279+
df=initial_conditions,
280+
regimes=self.regimes,
281+
regime_names_to_ids=self.regime_names_to_ids,
282+
)
283+
internal_params = self._process_params(params)
296284
if check_initial_conditions:
297285
validate_initial_conditions(
298286
initial_conditions=initial_conditions,
@@ -337,71 +325,59 @@ def simulate(
337325
)
338326
return result
339327

340-
341-
def _maybe_convert_series(
342-
internal_params: InternalParams,
343-
*,
344-
model: Model,
345-
derived_categoricals: Mapping[str, DiscreteGrid | Mapping[str, DiscreteGrid]]
346-
| None,
347-
) -> InternalParams:
348-
"""Convert pd.Series leaves in params to JAX arrays if any are present."""
349-
if derived_categoricals is not None or has_series(internal_params):
350-
return convert_series_in_params(
351-
internal_params=internal_params,
352-
model=model,
353-
derived_categoricals=derived_categoricals,
328+
def _process_params(self, params: UserParams) -> InternalParams:
329+
"""Broadcast, convert Series, and validate user params."""
330+
internal_params = process_params(
331+
params=params, params_template=self._params_template
354332
)
355-
return internal_params
333+
if has_series(internal_params):
334+
internal_params = convert_series_in_params(
335+
internal_params=internal_params,
336+
ages=self.ages,
337+
regimes=self.regimes,
338+
regime_names_to_ids=self.regime_names_to_ids,
339+
)
340+
_validate_param_types(internal_params)
341+
return internal_params
342+
356343

344+
def _merge_derived_categoricals(
345+
regimes: Mapping[str, Regime],
346+
derived_categoricals: Mapping[str, DiscreteGrid],
347+
) -> MappingProxyType[str, Regime]:
348+
"""Merge model-level derived_categoricals into each regime.
357349
358-
def _validate_param_types(internal_params: InternalParams) -> None:
359-
"""Raise if any param leaf is not a Python scalar or JAX array.
350+
Args:
351+
regimes: Mapping of regime names to Regime instances.
352+
derived_categoricals: Model-level categorical grids to broadcast.
353+
354+
Returns:
355+
Immutable mapping of regime names to (possibly updated) Regime instances.
356+
357+
Raises:
358+
ModelInitializationError: If a regime already has a conflicting entry
359+
(same key, different categories).
360360
361-
After processing, every leaf value (including inside MappingLeaf /
362-
SequenceLeaf containers) must be a Python scalar (float, int, bool) or a
363-
JAX array. Notably, numpy arrays and pandas Series are not accepted.
364361
"""
365-
for regime_name, regime_params in internal_params.items():
366-
for key, value in regime_params.items():
367-
_check_leaf(value, f"{regime_name}__{key}")
368-
369-
370-
def _check_leaf(value: object, path: str) -> None:
371-
"""Check a single leaf value, recursing into MappingLeaf/SequenceLeaf."""
372-
if isinstance(value, MappingLeaf):
373-
for k, v in value.data.items():
374-
_check_leaf(v, f"{path}.{k}")
375-
return
376-
if isinstance(value, SequenceLeaf):
377-
for i, v in enumerate(value.data):
378-
_check_leaf(v, f"{path}[{i}]")
379-
return
380-
if isinstance(value, (float, int, bool)):
381-
return
382-
if hasattr(value, "dtype") and hasattr(value, "shape"):
383-
if isinstance(value, Array):
384-
return
385-
type_name = type(value).__module__ + "." + type(value).__name__
386-
msg = (
387-
f"Parameter '{path}' is a {type_name} (shape {value.shape}). "
388-
f"Use jnp.array() or pass a pd.Series with a named index."
362+
if not derived_categoricals:
363+
return MappingProxyType(dict(regimes))
364+
result = {}
365+
for name, regime in regimes.items():
366+
merged = dict(regime.derived_categoricals)
367+
for var, grid in derived_categoricals.items():
368+
existing = merged.get(var)
369+
if existing is not None and existing.categories != grid.categories:
370+
msg = (
371+
f"Model-level derived_categoricals['{var}'] conflicts "
372+
f"with regime '{name}': {grid.categories} vs "
373+
f"{existing.categories}."
374+
)
375+
raise ModelInitializationError(msg)
376+
merged[var] = grid
377+
result[name] = dataclasses.replace(
378+
regime, derived_categoricals=MappingProxyType(merged)
389379
)
390-
raise InvalidParamsError(msg)
391-
type_name = type(value).__module__ + "." + type(value).__name__
392-
msg = f"Parameter '{path}' has unexpected type {type_name}."
393-
raise InvalidParamsError(msg)
394-
395-
396-
def _maybe_convert_dataframe(
397-
initial_conditions: Mapping[str, Array],
398-
*,
399-
model: Model,
400-
) -> Mapping[str, Array]:
401-
"""Convert a DataFrame to initial_conditions dict if needed."""
402-
if isinstance(initial_conditions, pd.DataFrame):
403-
return initial_conditions_from_dataframe(df=initial_conditions, model=model)
404-
return initial_conditions
380+
return MappingProxyType(result)
405381

406382

407383
def _validate_log_args(*, log_level: LogLevel, log_path: str | Path | None) -> None:

0 commit comments

Comments
 (0)