11"""Collection of classes that are used by the user to define the model and grids."""
22
3+ import dataclasses
34from collections .abc import Mapping
45from pathlib import Path
56from types import MappingProxyType
89from jax import Array
910
1011from lcm .ages import AgeGrid
11- from lcm .exceptions import InvalidParamsError
12+ from lcm .exceptions import ModelInitializationError
1213from lcm .grids import DiscreteGrid
1314from lcm .model_processing import (
15+ _validate_param_types ,
1416 build_regimes_and_template ,
1517 validate_model_inputs ,
1618)
1921 has_series ,
2022 initial_conditions_from_dataframe ,
2123)
22- from lcm .params import MappingLeaf , SequenceLeaf
2324from lcm .params .processing import (
2425 process_params ,
2526)
@@ -78,7 +79,7 @@ class Model:
7879 """Immutable mapping of regime names to internal regime instances."""
7980
8081 enable_jit : bool = True
81- """Whether to JIT-compile the functions of the internal regime ."""
82+ """Whether to JIT-compile the functions of the internal regimes ."""
8283
8384 fixed_params : UserParams
8485 """Parameters fixed at model initialization."""
@@ -95,6 +96,7 @@ def __init__(
9596 regime_id_class : type ,
9697 enable_jit : bool = True ,
9798 fixed_params : UserParams = MappingProxyType ({}),
99+ derived_categoricals : Mapping [str , DiscreteGrid ] = MappingProxyType ({}),
98100 ) -> None :
99101 """Initialize the Model.
100102
@@ -103,8 +105,13 @@ def __init__(
103105 ages: Age grid for the model.
104106 description: Description of the model.
105107 regime_id_class: Dataclass mapping regime names to integer indices.
106- enable_jit: Whether to jit the functions of the internal regime.
108+ enable_jit: Whether to JIT-compile the functions of the internal
109+ regimes.
107110 fixed_params: Parameters that can be fixed at model initialization.
111+ derived_categoricals: Categorical grids for DAG function outputs
112+ not in states/actions. Broadcast to all regimes (merged with
113+ each regime's own `derived_categoricals`). Raises if a regime
114+ already has a conflicting entry.
108115
109116 """
110117 self .description = description
@@ -125,10 +132,10 @@ def __init__(
125132 )
126133 )
127134 )
128- self .regimes = MappingProxyType ( dict ( regimes ) )
135+ self .regimes = _merge_derived_categoricals ( regimes , derived_categoricals )
129136 self .internal_regimes , self ._params_template = build_regimes_and_template (
130- regimes = regimes ,
131137 ages = self .ages ,
138+ regimes = self .regimes ,
132139 regime_names_to_ids = self .regime_names_to_ids ,
133140 enable_jit = enable_jit ,
134141 fixed_params = self .fixed_params ,
@@ -162,8 +169,6 @@ def solve(
162169 self ,
163170 * ,
164171 params : UserParams ,
165- derived_categoricals : Mapping [str , DiscreteGrid | Mapping [str , DiscreteGrid ]]
166- | None = None ,
167172 log_level : LogLevel = "progress" ,
168173 log_path : str | Path | None = None ,
169174 log_keep_n_latest : int = 3 ,
@@ -181,10 +186,6 @@ def solve(
181186 specification
182187 Values may be `pd.Series` with labeled indices; they are
183188 auto-converted to JAX arrays.
184- derived_categoricals: Extra categorical mappings (level name to
185- `DiscreteGrid`) for derived variables not in the model's
186- state/action grids. Pass per-regime mappings as
187- `{"var": {"regime_a": grid_a, ...}}`.
188189 log_level: Logging verbosity. `"off"` suppresses output, `"warning"` shows
189190 NaN/Inf warnings, `"progress"` adds timing, `"debug"` adds stats and
190191 requires `log_path`.
@@ -197,13 +198,7 @@ def solve(
197198
198199 """
199200 _validate_log_args (log_level = log_level , log_path = log_path )
200- internal_params = process_params (
201- params = params , params_template = self ._params_template
202- )
203- internal_params = _maybe_convert_series (
204- internal_params , model = self , derived_categoricals = derived_categoricals
205- )
206- _validate_param_types (internal_params )
201+ internal_params = self ._process_params (params )
207202 validate_regime_transitions_all_periods (
208203 internal_regimes = self .internal_regimes ,
209204 internal_params = internal_params ,
@@ -229,8 +224,6 @@ def simulate(
229224 self ,
230225 * ,
231226 params : UserParams ,
232- derived_categoricals : Mapping [str , DiscreteGrid | Mapping [str , DiscreteGrid ]]
233- | None = None ,
234227 initial_conditions : Mapping [str , Array ],
235228 period_to_regime_to_V_arr : MappingProxyType [
236229 int , MappingProxyType [RegimeName , FloatND ]
@@ -259,10 +252,6 @@ def simulate(
259252 specification
260253 Values may be `pd.Series` with labeled indices; they are
261254 auto-converted to JAX arrays.
262- derived_categoricals: Extra categorical mappings (level name to
263- `DiscreteGrid`) for derived variables not in the model's
264- state/action grids. Pass per-regime mappings as
265- `{"var": {"regime_a": grid_a, ...}}`.
266255 initial_conditions: Mapping of state names (plus `"regime"`) to arrays.
267256 All arrays must have the same length (number of subjects). The
268257 `"regime"` entry must contain integer regime codes (from
@@ -285,14 +274,13 @@ def simulate(
285274
286275 """
287276 _validate_log_args (log_level = log_level , log_path = log_path )
288- initial_conditions = _maybe_convert_dataframe (initial_conditions , model = self )
289- internal_params = process_params (
290- params = params , params_template = self ._params_template
291- )
292- internal_params = _maybe_convert_series (
293- internal_params , model = self , derived_categoricals = derived_categoricals
294- )
295- _validate_param_types (internal_params )
277+ if isinstance (initial_conditions , pd .DataFrame ):
278+ initial_conditions = initial_conditions_from_dataframe (
279+ df = initial_conditions ,
280+ regimes = self .regimes ,
281+ regime_names_to_ids = self .regime_names_to_ids ,
282+ )
283+ internal_params = self ._process_params (params )
296284 if check_initial_conditions :
297285 validate_initial_conditions (
298286 initial_conditions = initial_conditions ,
@@ -337,71 +325,59 @@ def simulate(
337325 )
338326 return result
339327
340-
341- def _maybe_convert_series (
342- internal_params : InternalParams ,
343- * ,
344- model : Model ,
345- derived_categoricals : Mapping [str , DiscreteGrid | Mapping [str , DiscreteGrid ]]
346- | None ,
347- ) -> InternalParams :
348- """Convert pd.Series leaves in params to JAX arrays if any are present."""
349- if derived_categoricals is not None or has_series (internal_params ):
350- return convert_series_in_params (
351- internal_params = internal_params ,
352- model = model ,
353- derived_categoricals = derived_categoricals ,
328+ def _process_params (self , params : UserParams ) -> InternalParams :
329+ """Broadcast, convert Series, and validate user params."""
330+ internal_params = process_params (
331+ params = params , params_template = self ._params_template
354332 )
355- return internal_params
333+ if has_series (internal_params ):
334+ internal_params = convert_series_in_params (
335+ internal_params = internal_params ,
336+ ages = self .ages ,
337+ regimes = self .regimes ,
338+ regime_names_to_ids = self .regime_names_to_ids ,
339+ )
340+ _validate_param_types (internal_params )
341+ return internal_params
342+
356343
344+ def _merge_derived_categoricals (
345+ regimes : Mapping [str , Regime ],
346+ derived_categoricals : Mapping [str , DiscreteGrid ],
347+ ) -> MappingProxyType [str , Regime ]:
348+ """Merge model-level derived_categoricals into each regime.
357349
358- def _validate_param_types (internal_params : InternalParams ) -> None :
359- """Raise if any param leaf is not a Python scalar or JAX array.
350+ Args:
351+ regimes: Mapping of regime names to Regime instances.
352+ derived_categoricals: Model-level categorical grids to broadcast.
353+
354+ Returns:
355+ Immutable mapping of regime names to (possibly updated) Regime instances.
356+
357+ Raises:
358+ ModelInitializationError: If a regime already has a conflicting entry
359+ (same key, different categories).
360360
361- After processing, every leaf value (including inside MappingLeaf /
362- SequenceLeaf containers) must be a Python scalar (float, int, bool) or a
363- JAX array. Notably, numpy arrays and pandas Series are not accepted.
364361 """
365- for regime_name , regime_params in internal_params .items ():
366- for key , value in regime_params .items ():
367- _check_leaf (value , f"{ regime_name } __{ key } " )
368-
369-
370- def _check_leaf (value : object , path : str ) -> None :
371- """Check a single leaf value, recursing into MappingLeaf/SequenceLeaf."""
372- if isinstance (value , MappingLeaf ):
373- for k , v in value .data .items ():
374- _check_leaf (v , f"{ path } .{ k } " )
375- return
376- if isinstance (value , SequenceLeaf ):
377- for i , v in enumerate (value .data ):
378- _check_leaf (v , f"{ path } [{ i } ]" )
379- return
380- if isinstance (value , (float , int , bool )):
381- return
382- if hasattr (value , "dtype" ) and hasattr (value , "shape" ):
383- if isinstance (value , Array ):
384- return
385- type_name = type (value ).__module__ + "." + type (value ).__name__
386- msg = (
387- f"Parameter '{ path } ' is a { type_name } (shape { value .shape } ). "
388- f"Use jnp.array() or pass a pd.Series with a named index."
362+ if not derived_categoricals :
363+ return MappingProxyType (dict (regimes ))
364+ result = {}
365+ for name , regime in regimes .items ():
366+ merged = dict (regime .derived_categoricals )
367+ for var , grid in derived_categoricals .items ():
368+ existing = merged .get (var )
369+ if existing is not None and existing .categories != grid .categories :
370+ msg = (
371+ f"Model-level derived_categoricals['{ var } '] conflicts "
372+ f"with regime '{ name } ': { grid .categories } vs "
373+ f"{ existing .categories } ."
374+ )
375+ raise ModelInitializationError (msg )
376+ merged [var ] = grid
377+ result [name ] = dataclasses .replace (
378+ regime , derived_categoricals = MappingProxyType (merged )
389379 )
390- raise InvalidParamsError (msg )
391- type_name = type (value ).__module__ + "." + type (value ).__name__
392- msg = f"Parameter '{ path } ' has unexpected type { type_name } ."
393- raise InvalidParamsError (msg )
394-
395-
396- def _maybe_convert_dataframe (
397- initial_conditions : Mapping [str , Array ],
398- * ,
399- model : Model ,
400- ) -> Mapping [str , Array ]:
401- """Convert a DataFrame to initial_conditions dict if needed."""
402- if isinstance (initial_conditions , pd .DataFrame ):
403- return initial_conditions_from_dataframe (df = initial_conditions , model = model )
404- return initial_conditions
380+ return MappingProxyType (result )
405381
406382
407383def _validate_log_args (* , log_level : LogLevel , log_path : str | Path | None ) -> None :
0 commit comments