99import pandas as pd
1010from dags import concatenate_functions , get_annotations , with_signature
1111from dags .signature import rename_arguments
12- from dags .tree import qname_from_tree_path , tree_path_from_qname
12+ from dags .tree import QNAME_DELIMITER , qname_from_tree_path , tree_path_from_qname
1313from jax import Array
1414from jax import numpy as jnp
1515
@@ -85,7 +85,7 @@ def process_regimes(
8585 The processed regimes.
8686
8787 """
88- states_per_regime : dict [str , set [str ]] = {
88+ states_per_regime : dict [RegimeName , set [str ]] = {
8989 name : set (regime .states .keys ()) for name , regime in regimes .items ()
9090 }
9191
@@ -174,7 +174,7 @@ def process_regimes(
174174def _build_solve_functions (
175175 * ,
176176 regime : Regime ,
177- regime_name : str ,
177+ regime_name : RegimeName ,
178178 nested_transitions : dict [str , dict [str , UserFunction ] | UserFunction ],
179179 all_grids : MappingProxyType [RegimeName , MappingProxyType [str , Grid ]],
180180 regime_params_template : RegimeParamsTemplate ,
@@ -262,7 +262,7 @@ def _build_solve_functions(
262262def _build_simulate_functions (
263263 * ,
264264 regime : Regime ,
265- regime_name : str ,
265+ regime_name : RegimeName ,
266266 nested_transitions : dict [str , dict [str , UserFunction ] | UserFunction ],
267267 all_grids : MappingProxyType [RegimeName , MappingProxyType [str , Grid ]],
268268 regime_params_template : RegimeParamsTemplate ,
@@ -508,13 +508,21 @@ def _process_regime_core(
508508 )
509509
510510 # Shock transitions bypass the stub pipeline entirely. Build weight and
511- # next functions for ALL target regimes directly from each target's grid.
511+ # next functions for reachable target regimes from each target's grid.
512+ # Scope to targets already present in non-shock transitions to avoid
513+ # spurious entries for unreachable regimes.
512514 shock_names = variable_info .query ("is_shock" ).index .tolist ()
513- target_shock_grids : dict [tuple [str , str ], _ShockGrid ] = { # ty: ignore[invalid-assignment]
514- (regime , shock ): grids [shock ]
515+ reachable_targets = {
516+ tree_path_from_qname (k )[0 ]
517+ for k in flat_nested_transitions
518+ if QNAME_DELIMITER in k
519+ }
520+ target_shock_grids : dict [tuple [RegimeName , str ], _ShockGrid ] = {
521+ (regime , shock ): grid
515522 for regime , grids in all_grids .items ()
523+ if regime in reachable_targets
516524 for shock in shock_names
517- if isinstance (grids .get (shock ), _ShockGrid )
525+ if isinstance (grid := grids .get (shock ), _ShockGrid )
518526 }
519527 functions |= {
520528 f"weight_{ regime } __next_{ shock } " : _get_weights_func_for_shock (
@@ -567,7 +575,7 @@ def _process_regime_core(
567575def _extract_transitions_from_regime (
568576 * ,
569577 regime : Regime ,
570- states_per_regime : Mapping [str , set [str ]],
578+ states_per_regime : Mapping [RegimeName , set [str ]],
571579) -> dict [str , dict [str , UserFunction ] | UserFunction ]:
572580 """Extract transitions from `regime.state_transitions` and regime transition.
573581
@@ -600,7 +608,14 @@ def _extract_transitions_from_regime(
600608 {"next_regime" : regime .transition },
601609 )
602610
603- for target_regime_name , target_regime_state_names in states_per_regime .items ():
611+ reachable_targets = _get_reachable_targets (
612+ per_target_transitions = per_target_transitions ,
613+ simple_transitions = simple_transitions ,
614+ states_per_regime = states_per_regime ,
615+ )
616+
617+ for target_regime_name in reachable_targets :
618+ target_regime_state_names = states_per_regime [target_regime_name ]
604619 target_dict : dict [str , UserFunction ] = {}
605620 for state_name in target_regime_state_names :
606621 next_key = f"next_{ state_name } "
@@ -616,6 +631,34 @@ def _extract_transitions_from_regime(
616631 return nested
617632
618633
634+ def _get_reachable_targets (
635+ * ,
636+ per_target_transitions : dict [str , dict [str , UserFunction ]],
637+ simple_transitions : dict [str , UserFunction ],
638+ states_per_regime : Mapping [RegimeName , set [str ]],
639+ ) -> set [RegimeName ]:
640+ """Determine which target regimes need transition entries.
641+
642+ When per-target transitions exist, start from the explicitly named targets
643+ and add any target whose state needs are fully covered by simple
644+ (non-per-target) transitions. Without per-target transitions, all regimes
645+ are reachable.
646+
647+ """
648+ if not per_target_transitions :
649+ return set (states_per_regime .keys ())
650+
651+ targets : set [RegimeName ] = set ()
652+ for variants in per_target_transitions .values ():
653+ targets |= variants .keys ()
654+ for target_name , target_states in states_per_regime .items ():
655+ if target_name not in targets :
656+ needed = {f"next_{ s } " for s in target_states }
657+ if needed and needed .issubset (simple_transitions ):
658+ targets .add (target_name )
659+ return targets
660+
661+
619662def _classify_transitions (
620663 state_transitions : dict [str , UserFunction ],
621664) -> tuple [dict [str , UserFunction ], dict [str , dict [str , UserFunction ]]]:
0 commit comments