77
88from collections .abc import Callable , Mapping , Sequence
99from types import MappingProxyType
10+ from typing import Never
1011
1112import jax
1213import numpy as np
1314import pandas as pd
1415from jax import Array
1516from jax import numpy as jnp
1617
17- from lcm .ages import AgeGrid
18+ from lcm .ages import PSEUDO_STATE_NAMES , AgeGrid
1819from lcm .exceptions import (
1920 InvalidInitialConditionsError ,
2021 format_messages ,
@@ -143,7 +144,10 @@ def validate_initial_conditions(
143144
144145 # Validate discrete state values
145146 _validate_discrete_state_values (
146- initial_states = initial_states , internal_regimes = internal_regimes
147+ initial_states = initial_states ,
148+ internal_regimes = internal_regimes ,
149+ regime_id_arr = regime_arr ,
150+ regime_names_to_ids = regime_names_to_ids ,
147151 )
148152
149153 # Validate feasibility
@@ -194,7 +198,7 @@ def _format_missing_states_message(missing: set[str], required: set[str]) -> str
194198 "knows each subject's starting age. Example: "
195199 "initial_states={'age': jnp.array([25.0, 25.0]), ...}"
196200 )
197- missing_model_states = sorted (missing - { "age" } )
201+ missing_model_states = sorted (missing - PSEUDO_STATE_NAMES )
198202 if missing_model_states :
199203 parts .append (f"Missing model states: { missing_model_states } ." )
200204 parts .append (f"Required initial states are: { sorted (required )} " )
@@ -230,12 +234,12 @@ def _collect_state_name_errors(
230234 errors : list [str ] = []
231235
232236 # All known states (union across all regimes) — used for the "extra" check
233- all_known_states : set [str ] = { "age" }
237+ all_known_states : set [str ] = set ( PSEUDO_STATE_NAMES )
234238 for internal_regime in internal_regimes .values ():
235239 all_known_states .update (_get_regime_state_names (internal_regime ))
236240
237241 # Required states — only from regimes subjects actually start in
238- required_states : set [str ] = { "age" }
242+ required_states : set [str ] = set ( PSEUDO_STATE_NAMES )
239243 used_ids = jnp .unique (regime_id_arr )
240244 used_regime_names = {
241245 ids_to_regime_names [int (i )] for i in used_ids if int (i ) in ids_to_regime_names
@@ -414,35 +418,53 @@ def _validate_discrete_state_values(
414418 * ,
415419 initial_states : Mapping [str , Array ],
416420 internal_regimes : MappingProxyType [RegimeName , InternalRegime ],
421+ regime_id_arr : Array ,
422+ regime_names_to_ids : Mapping [str , int ],
417423) -> None :
418424 """Validate that discrete state values are valid codes.
419425
426+ Only check subjects in regimes that actually have the state.
427+
420428 Args:
421429 initial_states: Mapping of state names to arrays.
422430 internal_regimes: Immutable mapping of regime names to internal regime
423431 instances.
432+ regime_id_arr: Array of regime IDs for each subject.
433+ regime_names_to_ids: Mapping from regime names to integer IDs.
424434
425435 Raises:
426436 InvalidInitialConditionsError: If any discrete state contains invalid codes.
427437
428438 """
429- discrete_valid_codes : dict [str , set [int ]] = {}
430- for internal_regime in internal_regimes .values ():
439+ # Build per-state: valid codes + regime IDs that have this state
440+ discrete_info : dict [str , tuple [set [int ], set [int ]]] = {}
441+ for regime_name , internal_regime in internal_regimes .items ():
442+ regime_id = regime_names_to_ids [regime_name ]
431443 for state_name in internal_regime .variable_info .query (
432444 "is_state and is_discrete"
433445 ).index :
434446 grid = internal_regime .grids [state_name ]
435447 if isinstance (grid , DiscreteGrid ):
436- existing = discrete_valid_codes .get (state_name , set ())
437- discrete_valid_codes [state_name ] = existing | set (grid .codes )
448+ codes , regime_ids = discrete_info .get (state_name , (set (), set ()))
449+ discrete_info [state_name ] = (
450+ codes | set (grid .codes ),
451+ regime_ids | {regime_id },
452+ )
438453
439- for state_name , valid_codes in discrete_valid_codes .items ():
454+ for state_name , ( valid_codes , regime_ids ) in discrete_info .items ():
440455 if state_name not in initial_states :
441456 continue
442457 values = initial_states [state_name ]
443- invalid_mask = jnp .isin (values , jnp .array (sorted (valid_codes )), invert = True )
458+ # Only validate subjects in regimes that have this state
459+ in_relevant_regime = jnp .isin (regime_id_arr , jnp .array (sorted (regime_ids )))
460+ relevant_values = values [in_relevant_regime ]
461+ if relevant_values .size == 0 :
462+ continue
463+ invalid_mask = jnp .isin (
464+ relevant_values , jnp .array (sorted (valid_codes )), invert = True
465+ )
444466 if jnp .any (invalid_mask ):
445- invalid_vals = sorted ({int (v ) for v in values [invalid_mask ]})
467+ invalid_vals = sorted ({int (v ) for v in relevant_values [invalid_mask ]})
446468 raise InvalidInitialConditionsError (
447469 f"Invalid values { invalid_vals } for discrete state "
448470 f"'{ state_name } '. Valid codes are: { sorted (valid_codes )} "
@@ -523,7 +545,7 @@ def _is_any_action_feasible(per_subject_kwargs: dict[str, Array]) -> Array:
523545 return jnp .concatenate (results )
524546
525547
526- def _check_regime_feasibility (
548+ def _check_regime_feasibility ( # noqa: C901
527549 * ,
528550 internal_regime : InternalRegime ,
529551 regime_name : str ,
@@ -587,13 +609,21 @@ def _check_regime_feasibility(
587609 }
588610
589611 if subject_states :
590- any_feasible = _batched_feasibility_check (
591- feasibility_func = feasibility_func ,
592- subject_states = subject_states ,
593- action_kwargs = action_kwargs ,
594- filtered_params = filtered_params ,
595- flat_actions = flat_actions ,
596- )
612+ try :
613+ any_feasible = _batched_feasibility_check (
614+ feasibility_func = feasibility_func ,
615+ subject_states = subject_states ,
616+ action_kwargs = action_kwargs ,
617+ filtered_params = filtered_params ,
618+ flat_actions = flat_actions ,
619+ )
620+ except TypeError as exc :
621+ _raise_feasibility_type_error (
622+ exc = exc ,
623+ regime_name = regime_name ,
624+ internal_regime = internal_regime ,
625+ subject_states = subject_states ,
626+ )
597627 infeasible_mask = np .asarray (~ any_feasible )
598628 infeasible_indices = np .asarray (idx_arr )[infeasible_mask ].tolist ()
599629 else :
@@ -620,6 +650,52 @@ def _check_combo(action_kw: dict[str, Array]) -> Array:
620650 )
621651
622652
653+ def _raise_feasibility_type_error (
654+ * ,
655+ exc : TypeError ,
656+ regime_name : str ,
657+ internal_regime : InternalRegime ,
658+ subject_states : dict [str , Array ],
659+ ) -> Never :
660+ """Re-raise a TypeError from feasibility checking with diagnostic context.
661+
662+ Args:
663+ exc: The original TypeError from the feasibility check.
664+ regime_name: Name of the regime being checked.
665+ internal_regime: The internal regime containing variable info.
666+ subject_states: Mapping of state names to arrays for subjects in
667+ this regime.
668+
669+ Raises:
670+ InvalidInitialConditionsError: Always — wraps `exc` with a dtype hint
671+ when any discrete state has a non-integer dtype.
672+
673+ """
674+ discrete_names = {
675+ name
676+ for name , grid in internal_regime .grids .items ()
677+ if isinstance (grid , DiscreteGrid )
678+ }
679+
680+ bad_dtypes : list [str ] = []
681+ for name , arr in subject_states .items ():
682+ if name in discrete_names and not jnp .issubdtype (arr .dtype , jnp .integer ):
683+ bad_dtypes .append (f" { name !r} : dtype={ arr .dtype } (expected integer)" )
684+
685+ hint = ""
686+ if bad_dtypes :
687+ hint = (
688+ "\n \n Discrete states with wrong dtype:\n "
689+ + "\n " .join (bad_dtypes )
690+ + "\n \n Discrete states are used as array indices and must have integer "
691+ "dtype. Check that initial conditions encode categorical states as int "
692+ "codes, not floats."
693+ )
694+
695+ msg = f"TypeError in feasibility check for regime { regime_name !r} : { exc } { hint } "
696+ raise InvalidInitialConditionsError (msg ) from exc
697+
698+
623699def _format_infeasibility_message (
624700 * ,
625701 infeasible_indices : Sequence [int ],
0 commit comments