2020from lcm .ages import AgeGrid
2121from lcm .grids import Grid
2222from lcm .interfaces import StateActionSpace
23- from lcm .regime import Regime
24- from lcm .regime_building .Q_and_F import get_compute_intermediates
23+ from lcm .regime_building .Q_and_F import get_complete_targets , get_compute_intermediates
2524from lcm .regime_building .V import VInterpolationInfo
2625from lcm .typing import (
2726 FunctionsMapping ,
3433
3534def _build_compute_intermediates_per_period (
3635 * ,
37- regime : Regime ,
3836 flat_param_names : frozenset [str ],
3937 regimes_to_active_periods : MappingProxyType [RegimeName , tuple [int , ...]],
4038 functions : FunctionsMapping ,
4139 constraints : FunctionsMapping ,
4240 transitions : TransitionFunctionsMapping ,
4341 stochastic_transition_names : frozenset [str ],
44- compute_regime_transition_probs : RegimeTransitionFunction | None ,
42+ compute_regime_transition_probs : RegimeTransitionFunction ,
4543 regime_to_v_interpolation_info : MappingProxyType [RegimeName , VInterpolationInfo ],
4644 state_action_space : StateActionSpace ,
4745 grids : MappingProxyType [str , Grid ],
4846 ages : AgeGrid ,
4947 enable_jit : bool ,
5048) -> MappingProxyType [int , Callable ]:
51- """Build diagnostic intermediate closures for each period.
49+ """Build diagnostic intermediate closures for each period of a non-terminal regime .
5250
53- The closures fuse a productmap over the full state-action space with
51+ Each closure fuses a productmap over the full state-action space with
5452 on-device reductions (matching the `max_Q_over_a` productmap pattern)
55- and are JIT-compiled. Used in the error path when `validate_V` detects
56- NaN; returns an empty mapping for terminal regimes.
53+ and is JIT-compiled. Periods sharing the same target configuration
54+ reuse a single scalar closure. The caller is responsible for handling
55+ terminal regimes. Used in the error path when `validate_V` detects NaN.
5756
5857 Args:
59- regime: User regime; only the terminal flag is consulted.
60- flat_param_names: Frozenset of flat parameter names for the regime;
61- forwarded to `get_compute_intermediates` for the explicit
62- signature productmap requires.
58+ flat_param_names: Frozenset of flat parameter names for the regime.
6359 regimes_to_active_periods: Immutable mapping of regime names to
6460 their active period tuples.
6561 functions: Immutable mapping of internal user functions.
@@ -69,7 +65,7 @@ def _build_compute_intermediates_per_period(
6965 stochastic_transition_names: Frozenset of stochastic transition
7066 function names.
7167 compute_regime_transition_probs: Regime transition probability
72- function, or `None` for terminal regimes .
68+ function for the current regime .
7369 regime_to_v_interpolation_info: Mapping of regime names to
7470 V-interpolation info.
7571 state_action_space: State-action space used for productmap sizing.
@@ -79,37 +75,39 @@ def _build_compute_intermediates_per_period(
7975 enable_jit: Whether to JIT-compile the fused closure.
8076
8177 Returns:
82- Immutable mapping of period index to fused closure; empty for
83- terminal regimes.
78+ Immutable mapping of period index to fused closure.
8479
8580 """
86- if regime .terminal :
87- return MappingProxyType ({})
88-
89- assert compute_regime_transition_probs is not None # noqa: S101
90-
9181 state_batch_sizes = {
9282 name : grid .batch_size
9383 for name , grid in grids .items ()
9484 if name in state_action_space .state_names
9585 }
9686
87+ configs : dict [tuple [str , ...], list [int ]] = {}
88+ for period in range (ages .n_periods ):
89+ complete = get_complete_targets (
90+ period = period ,
91+ transitions = transitions ,
92+ regimes_to_active_periods = regimes_to_active_periods ,
93+ stochastic_transition_names = stochastic_transition_names ,
94+ regime_to_v_interpolation_info = regime_to_v_interpolation_info ,
95+ )
96+ configs .setdefault (complete , []).append (period )
97+
9798 variable_names = (
9899 * state_action_space .state_names ,
99100 * state_action_space .action_names ,
100101 )
101-
102- intermediates : dict [int , Callable ] = {}
103- for period , age in enumerate (ages .values ):
102+ built : dict [tuple [str , ...], Callable ] = {}
103+ for complete_targets in configs :
104104 scalar = get_compute_intermediates (
105- age = age ,
106- period = period ,
107105 flat_param_names = flat_param_names ,
108106 functions = functions ,
109107 constraints = constraints ,
108+ complete_targets = complete_targets ,
110109 transitions = transitions ,
111110 stochastic_transition_names = stochastic_transition_names ,
112- regimes_to_active_periods = regimes_to_active_periods ,
113111 compute_regime_transition_probs = compute_regime_transition_probs ,
114112 regime_to_v_interpolation_info = regime_to_v_interpolation_info ,
115113 )
@@ -119,13 +117,15 @@ def _build_compute_intermediates_per_period(
119117 state_names = state_action_space .state_names ,
120118 state_batch_sizes = state_batch_sizes ,
121119 )
122- fused = _wrap_with_reduction (
123- func = mapped ,
124- variable_names = variable_names ,
125- )
126- intermediates [period ] = jax .jit (fused ) if enable_jit else fused
120+ fused = _wrap_with_reduction (func = mapped , variable_names = variable_names )
121+ built [complete_targets ] = jax .jit (fused ) if enable_jit else fused
122+
123+ result : dict [int , Callable ] = {}
124+ for key , periods in configs .items ():
125+ for period in periods :
126+ result [period ] = built [key ]
127127
128- return MappingProxyType (intermediates )
128+ return MappingProxyType (result )
129129
130130
131131def _wrap_with_reduction (
0 commit comments