Skip to content

Commit 6af04b2

Browse files
hmgaudecker and claude authored
Support heterogeneous state sets in initial conditions (#315)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 719b533 commit 6af04b2

5 files changed

Lines changed: 484 additions & 29 deletions

File tree

src/lcm/ages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
}
2121
)
2222

23+
# Names that behave like states in initial conditions but are not declared on
24+
# any `Regime.states`. `age` is required for every subject regardless of regime.
25+
PSEUDO_STATE_NAMES: frozenset[str] = frozenset({"age"})
26+
2327

2428
class AgeGrid:
2529
"""Age grid for life-cycle models.

src/lcm/pandas_utils.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
from dags.tree import qname_from_tree_path, tree_path_from_qname
1212
from jax import Array
1313

14-
from lcm.ages import AgeGrid
14+
from lcm.ages import PSEUDO_STATE_NAMES, AgeGrid
1515
from lcm.grids import DiscreteGrid, IrregSpacedGrid
1616
from lcm.params import MappingLeaf
1717
from lcm.params.sequence_leaf import SequenceLeaf
1818
from lcm.regime import Regime
1919
from lcm.shocks import _ShockGrid
20+
from lcm.simulation.initial_conditions import MISSING_CAT_CODE
2021
from lcm.typing import InternalParams, RegimeNamesToIds
2122
from lcm.utils.error_handling import (
2223
_get_func_indexing_params,
@@ -39,7 +40,7 @@ def has_series(params: Mapping) -> bool:
3940
return False
4041

4142

42-
def initial_conditions_from_dataframe(
43+
def initial_conditions_from_dataframe( # noqa: C901
4344
*,
4445
df: pd.DataFrame,
4546
regimes: Mapping[str, Regime],
@@ -92,9 +93,9 @@ def initial_conditions_from_dataframe(
9293
n_subjects = len(df)
9394
state_cols = [col for col in df.columns if col != "regime"]
9495

95-
# Pre-allocate result arrays
96+
# Pre-allocate result arrays (NaN default surfaces bugs for missing states)
9697
result_arrays: dict[str, np.ndarray] = {
97-
col: np.empty(n_subjects, dtype=float) for col in state_cols
98+
col: np.full(n_subjects, np.nan) for col in state_cols
9899
}
99100
discrete_state_names: set[str] = set()
100101

@@ -109,7 +110,12 @@ def initial_conditions_from_dataframe(
109110
}
110111
discrete_state_names |= discrete_grids.keys()
111112

113+
regime_state_names = set(regime.states.keys()) | PSEUDO_STATE_NAMES
114+
112115
for col in state_cols:
116+
if col not in regime_state_names:
117+
continue
118+
113119
values = group[col]
114120
if hasattr(values, "cat"):
115121
values = values.astype(str)
@@ -126,6 +132,14 @@ def initial_conditions_from_dataframe(
126132
else:
127133
result_arrays[col][idx] = values.to_numpy(dtype=float)
128134

135+
# Replace remaining NaN in discrete columns with an explicit int sentinel
136+
# before casting to int32. This avoids platform-undefined NaN→int behavior
137+
# and the associated RuntimeWarning.
138+
for col in discrete_state_names:
139+
if col in result_arrays:
140+
nan_mask = np.isnan(result_arrays[col])
141+
result_arrays[col][nan_mask] = MISSING_CAT_CODE
142+
129143
initial_conditions: dict[str, Array] = {
130144
col: jnp.array(arr, dtype=jnp.int32)
131145
if col in discrete_state_names
@@ -786,17 +800,35 @@ def _validate_state_columns(
786800
unknown = state_columns - expected
787801
if unknown:
788802
msg = (
789-
f"Unknown columns not matching any model state: {sorted(unknown)}. "
803+
f"Unknown columns not matching any state of an initial regime: "
804+
f"{sorted(unknown)}. "
790805
f"Expected states: {sorted(expected)}."
791806
)
792807
raise ValueError(msg)
793808

794809
missing = expected - state_columns
795810
if missing:
796-
msg = f"Missing required state columns: {sorted(missing)}."
811+
required_by: dict[str, list[str]] = {name: [] for name in missing}
812+
for regime_name in set(initial_regimes):
813+
for name in regimes[regime_name].states:
814+
if name in required_by:
815+
required_by[name].append(regime_name)
816+
details = ", ".join(
817+
_format_missing_state_detail(name=name, required_by=required_by[name])
818+
for name in sorted(missing)
819+
)
820+
msg = f"Missing required state columns: {details}."
797821
raise ValueError(msg)
798822

799823

824+
def _format_missing_state_detail(*, name: str, required_by: list[str]) -> str:
    """Describe a single missing state column together with who requires it.

    Args:
        name: Name of the missing state column.
        required_by: Names of the initial regimes that list `name` among their
            states. May be empty.

    Returns:
        The quoted column name followed by a parenthesised reason: pseudo-states
        are reported as required for every subject, otherwise the sorted list of
        requiring regimes is shown, with a generic fallback when that list is
        empty.

    """
    if name in PSEUDO_STATE_NAMES:
        reason = "required for every subject"
    elif required_by:
        reason = f"required by {sorted(required_by)}"
    else:
        reason = "required by an initial regime"
    return f"'{name}' ({reason})"
830+
831+
800832
def _collect_state_names(
801833
*,
802834
regimes: Mapping[str, Regime],
@@ -805,11 +837,11 @@ def _collect_state_names(
805837
"""Collect all state names (including shock grids) from initial regimes.
806838
807839
Returns:
808-
Set of all state names from the initial regimes, plus `'age'`
809-
(always required).
840+
Set of all state names from the initial regimes, plus the pseudo-state
841+
names from `PSEUDO_STATE_NAMES` (always required).
810842
811843
"""
812-
names: set[str] = {"age"}
844+
names: set[str] = set(PSEUDO_STATE_NAMES)
813845
for regime_name in set(initial_regimes):
814846
names.update(regimes[regime_name].states.keys())
815847
return names

src/lcm/simulation/initial_conditions.py

Lines changed: 96 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77

88
from collections.abc import Callable, Mapping, Sequence
99
from types import MappingProxyType
10+
from typing import Never
1011

1112
import jax
1213
import numpy as np
1314
import pandas as pd
1415
from jax import Array
1516
from jax import numpy as jnp
1617

17-
from lcm.ages import AgeGrid
18+
from lcm.ages import PSEUDO_STATE_NAMES, AgeGrid
1819
from lcm.exceptions import (
1920
InvalidInitialConditionsError,
2021
format_messages,
@@ -143,7 +144,10 @@ def validate_initial_conditions(
143144

144145
# Validate discrete state values
145146
_validate_discrete_state_values(
146-
initial_states=initial_states, internal_regimes=internal_regimes
147+
initial_states=initial_states,
148+
internal_regimes=internal_regimes,
149+
regime_id_arr=regime_arr,
150+
regime_names_to_ids=regime_names_to_ids,
147151
)
148152

149153
# Validate feasibility
@@ -194,7 +198,7 @@ def _format_missing_states_message(missing: set[str], required: set[str]) -> str
194198
"knows each subject's starting age. Example: "
195199
"initial_states={'age': jnp.array([25.0, 25.0]), ...}"
196200
)
197-
missing_model_states = sorted(missing - {"age"})
201+
missing_model_states = sorted(missing - PSEUDO_STATE_NAMES)
198202
if missing_model_states:
199203
parts.append(f"Missing model states: {missing_model_states}.")
200204
parts.append(f"Required initial states are: {sorted(required)}")
@@ -230,12 +234,12 @@ def _collect_state_name_errors(
230234
errors: list[str] = []
231235

232236
# All known states (union across all regimes) — used for the "extra" check
233-
all_known_states: set[str] = {"age"}
237+
all_known_states: set[str] = set(PSEUDO_STATE_NAMES)
234238
for internal_regime in internal_regimes.values():
235239
all_known_states.update(_get_regime_state_names(internal_regime))
236240

237241
# Required states — only from regimes subjects actually start in
238-
required_states: set[str] = {"age"}
242+
required_states: set[str] = set(PSEUDO_STATE_NAMES)
239243
used_ids = jnp.unique(regime_id_arr)
240244
used_regime_names = {
241245
ids_to_regime_names[int(i)] for i in used_ids if int(i) in ids_to_regime_names
@@ -414,35 +418,53 @@ def _validate_discrete_state_values(
414418
*,
415419
initial_states: Mapping[str, Array],
416420
internal_regimes: MappingProxyType[RegimeName, InternalRegime],
421+
regime_id_arr: Array,
422+
regime_names_to_ids: Mapping[str, int],
417423
) -> None:
418424
"""Validate that discrete state values are valid codes.
419425
426+
Only check subjects in regimes that actually have the state.
427+
420428
Args:
421429
initial_states: Mapping of state names to arrays.
422430
internal_regimes: Immutable mapping of regime names to internal regime
423431
instances.
432+
regime_id_arr: Array of regime IDs for each subject.
433+
regime_names_to_ids: Mapping from regime names to integer IDs.
424434
425435
Raises:
426436
InvalidInitialConditionsError: If any discrete state contains invalid codes.
427437
428438
"""
429-
discrete_valid_codes: dict[str, set[int]] = {}
430-
for internal_regime in internal_regimes.values():
439+
# Build per-state: valid codes + regime IDs that have this state
440+
discrete_info: dict[str, tuple[set[int], set[int]]] = {}
441+
for regime_name, internal_regime in internal_regimes.items():
442+
regime_id = regime_names_to_ids[regime_name]
431443
for state_name in internal_regime.variable_info.query(
432444
"is_state and is_discrete"
433445
).index:
434446
grid = internal_regime.grids[state_name]
435447
if isinstance(grid, DiscreteGrid):
436-
existing = discrete_valid_codes.get(state_name, set())
437-
discrete_valid_codes[state_name] = existing | set(grid.codes)
448+
codes, regime_ids = discrete_info.get(state_name, (set(), set()))
449+
discrete_info[state_name] = (
450+
codes | set(grid.codes),
451+
regime_ids | {regime_id},
452+
)
438453

439-
for state_name, valid_codes in discrete_valid_codes.items():
454+
for state_name, (valid_codes, regime_ids) in discrete_info.items():
440455
if state_name not in initial_states:
441456
continue
442457
values = initial_states[state_name]
443-
invalid_mask = jnp.isin(values, jnp.array(sorted(valid_codes)), invert=True)
458+
# Only validate subjects in regimes that have this state
459+
in_relevant_regime = jnp.isin(regime_id_arr, jnp.array(sorted(regime_ids)))
460+
relevant_values = values[in_relevant_regime]
461+
if relevant_values.size == 0:
462+
continue
463+
invalid_mask = jnp.isin(
464+
relevant_values, jnp.array(sorted(valid_codes)), invert=True
465+
)
444466
if jnp.any(invalid_mask):
445-
invalid_vals = sorted({int(v) for v in values[invalid_mask]})
467+
invalid_vals = sorted({int(v) for v in relevant_values[invalid_mask]})
446468
raise InvalidInitialConditionsError(
447469
f"Invalid values {invalid_vals} for discrete state "
448470
f"'{state_name}'. Valid codes are: {sorted(valid_codes)}"
@@ -523,7 +545,7 @@ def _is_any_action_feasible(per_subject_kwargs: dict[str, Array]) -> Array:
523545
return jnp.concatenate(results)
524546

525547

526-
def _check_regime_feasibility(
548+
def _check_regime_feasibility( # noqa: C901
527549
*,
528550
internal_regime: InternalRegime,
529551
regime_name: str,
@@ -587,13 +609,21 @@ def _check_regime_feasibility(
587609
}
588610

589611
if subject_states:
590-
any_feasible = _batched_feasibility_check(
591-
feasibility_func=feasibility_func,
592-
subject_states=subject_states,
593-
action_kwargs=action_kwargs,
594-
filtered_params=filtered_params,
595-
flat_actions=flat_actions,
596-
)
612+
try:
613+
any_feasible = _batched_feasibility_check(
614+
feasibility_func=feasibility_func,
615+
subject_states=subject_states,
616+
action_kwargs=action_kwargs,
617+
filtered_params=filtered_params,
618+
flat_actions=flat_actions,
619+
)
620+
except TypeError as exc:
621+
_raise_feasibility_type_error(
622+
exc=exc,
623+
regime_name=regime_name,
624+
internal_regime=internal_regime,
625+
subject_states=subject_states,
626+
)
597627
infeasible_mask = np.asarray(~any_feasible)
598628
infeasible_indices = np.asarray(idx_arr)[infeasible_mask].tolist()
599629
else:
@@ -620,6 +650,52 @@ def _check_combo(action_kw: dict[str, Array]) -> Array:
620650
)
621651

622652

653+
def _raise_feasibility_type_error(
    *,
    exc: TypeError,
    regime_name: str,
    internal_regime: InternalRegime,
    subject_states: dict[str, Array],
) -> Never:
    """Wrap a feasibility-check TypeError in an initial-conditions error.

    The resulting message always names the regime and the original error text.
    If any state backed by a `DiscreteGrid` was supplied with a non-integer
    dtype, a hint listing those states is appended.

    Args:
        exc: The TypeError caught while running the feasibility check.
        regime_name: Name of the regime that was being checked.
        internal_regime: Internal regime whose `grids` identify discrete states.
        subject_states: Mapping of state names to arrays for the subjects in
            this regime.

    Raises:
        InvalidInitialConditionsError: Always; chained from `exc`.

    """
    # Names whose grid is discrete; only these are subject to the dtype check.
    discrete_grid_names = set()
    for grid_name, grid in internal_regime.grids.items():
        if isinstance(grid, DiscreteGrid):
            discrete_grid_names.add(grid_name)

    dtype_complaints = [
        f" {name!r}: dtype={values.dtype} (expected integer)"
        for name, values in subject_states.items()
        if name in discrete_grid_names
        and not jnp.issubdtype(values.dtype, jnp.integer)
    ]

    if dtype_complaints:
        listing = "\n".join(dtype_complaints)
        hint = (
            "\n\nDiscrete states with wrong dtype:\n"
            f"{listing}"
            "\n\nDiscrete states are used as array indices and must have integer "
            "dtype. Check that initial conditions encode categorical states as int "
            "codes, not floats."
        )
    else:
        hint = ""

    msg = f"TypeError in feasibility check for regime {regime_name!r}: {exc}{hint}"
    raise InvalidInitialConditionsError(msg) from exc
697+
698+
623699
def _format_infeasibility_message(
624700
*,
625701
infeasible_indices: Sequence[int],

0 commit comments

Comments
 (0)