Skip to content

Commit d5fe32c

Browse files
hmgaudecker and claude authored
Accept shock grid columns in initial_conditions_from_dataframe (#306)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a4cd1a3 commit d5fe32c

2 files changed

Lines changed: 50 additions & 20 deletions

File tree

src/lcm/pandas_utils.py

Lines changed: 16 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -800,42 +800,38 @@ def _validate_state_columns(
800800
initial_regimes: list[str],
801801
) -> None:
802802
"""Validate that DataFrame columns match model states."""
803-
all_states = _collect_all_state_names(
804-
regimes=regimes, initial_regimes=initial_regimes
805-
)
803+
expected = _collect_state_names(regimes=regimes, initial_regimes=initial_regimes)
806804

807-
unknown = state_columns - all_states
805+
unknown = state_columns - expected
808806
if unknown:
809807
msg = (
810808
f"Unknown columns not matching any model state: {sorted(unknown)}. "
811-
f"Expected states: {sorted(all_states)}."
809+
f"Expected states: {sorted(expected)}."
812810
)
813811
raise ValueError(msg)
814812

815-
missing = all_states - state_columns
813+
missing = expected - state_columns
816814
if missing:
817-
msg = (
818-
f"Missing required state columns: {sorted(missing)}. "
819-
f"All non-shock states must be provided."
820-
)
815+
msg = f"Missing required state columns: {sorted(missing)}."
821816
raise ValueError(msg)
822817

823818

824-
def _collect_all_state_names(
819+
def _collect_state_names(
825820
*,
826821
regimes: Mapping[str, Regime],
827822
initial_regimes: list[str],
828823
) -> set[str]:
829-
"""Collect all non-shock state names from regimes present in initial_regimes."""
830-
state_names: set[str] = set()
824+
"""Collect all state names (including shock grids) from initial regimes.
825+
826+
Returns:
827+
Set of all state names from the initial regimes, plus `'age'`
828+
(always required).
829+
830+
"""
831+
names: set[str] = {"age"}
831832
for regime_name in set(initial_regimes):
832-
regime = regimes[regime_name]
833-
for name, grid in regime.states.items():
834-
if not isinstance(grid, _ShockGrid):
835-
state_names.add(name)
836-
# Always include age
837-
state_names.add("age")
838-
return state_names
833+
names.update(regimes[regime_name].states.keys())
834+
return names
839835

840836

841837
def _build_discrete_grid_lookup(

tests/test_pandas_utils.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
get_model as get_basic_model,
3333
)
3434
from tests.test_models.regime_markov import get_model as get_regime_markov_model
35+
from tests.test_models.shock_grids import get_model as get_shock_model
3536
from tests.test_models.stochastic import get_model as get_stochastic_model
3637

3738

@@ -270,6 +271,39 @@ def test_missing_state_column_raises():
270271
initial_conditions_from_dataframe(df=df, model=model)
271272

272273

274+
def test_shock_state_columns_accepted():
275+
"""Shock grid columns are accepted as continuous float columns."""
276+
model = get_shock_model(n_periods=4, distribution_type="uniform")
277+
df = pd.DataFrame(
278+
{
279+
"regime": ["alive", "alive"],
280+
"wealth": [2.0, 4.0],
281+
"health": ["bad", "good"],
282+
"income": [0.3, 0.7],
283+
"age": [0.0, 0.0],
284+
}
285+
)
286+
conditions = initial_conditions_from_dataframe(df=df, model=model)
287+
assert jnp.allclose(conditions["income"], jnp.array([0.3, 0.7]))
288+
assert jnp.allclose(conditions["wealth"], jnp.array([2.0, 4.0]))
289+
assert "regime" in conditions
290+
291+
292+
def test_shock_state_columns_required():
293+
"""DataFrame without shock columns raises (shocks are required)."""
294+
model = get_shock_model(n_periods=4, distribution_type="uniform")
295+
df = pd.DataFrame(
296+
{
297+
"regime": ["alive", "alive"],
298+
"wealth": [2.0, 4.0],
299+
"health": ["bad", "good"],
300+
"age": [0.0, 0.0],
301+
}
302+
)
303+
with pytest.raises(ValueError, match=r"Missing required state columns.*income"):
304+
initial_conditions_from_dataframe(df=df, model=model)
305+
306+
273307
def test_round_trip_with_discrete_model():
274308
"""Verify DataFrame-based initial states match raw arrays."""
275309
from tests.test_models.deterministic.discrete import ( # noqa: PLC0415

0 commit comments

Comments
 (0)