@@ -800,42 +800,38 @@ def _validate_state_columns(
800800 initial_regimes : list [str ],
801801) -> None :
802802 """Validate that DataFrame columns match model states."""
803- all_states = _collect_all_state_names (
804- regimes = regimes , initial_regimes = initial_regimes
805- )
803+ expected = _collect_state_names (regimes = regimes , initial_regimes = initial_regimes )
806804
807- unknown = state_columns - all_states
805+ unknown = state_columns - expected
808806 if unknown :
809807 msg = (
810808 f"Unknown columns not matching any model state: { sorted (unknown )} . "
811- f"Expected states: { sorted (all_states )} ."
809+ f"Expected states: { sorted (expected )} ."
812810 )
813811 raise ValueError (msg )
814812
815- missing = all_states - state_columns
813+ missing = expected - state_columns
816814 if missing :
817- msg = (
818- f"Missing required state columns: { sorted (missing )} . "
819- f"All non-shock states must be provided."
820- )
815+ msg = f"Missing required state columns: { sorted (missing )} ."
821816 raise ValueError (msg )
822817
823818
824- def _collect_all_state_names (
819+ def _collect_state_names (
825820 * ,
826821 regimes : Mapping [str , Regime ],
827822 initial_regimes : list [str ],
828823) -> set [str ]:
829- """Collect all non-shock state names from regimes present in initial_regimes."""
830- state_names : set [str ] = set ()
824+ """Collect all state names (including shock grids) from initial regimes.
825+
826+ Returns:
827+ Set of all state names from the initial regimes, plus `'age'`
828+ (always required).
829+
830+ """
831+ names : set [str ] = {"age" }
831832 for regime_name in set (initial_regimes ):
832- regime = regimes [regime_name ]
833- for name , grid in regime .states .items ():
834- if not isinstance (grid , _ShockGrid ):
835- state_names .add (name )
836- # Always include age
837- state_names .add ("age" )
838- return state_names
833+ names .update (regimes [regime_name ].states .keys ())
834+ return names
839835
840836
841837def _build_discrete_grid_lookup (
0 commit comments