Skip to content

Commit 114f9e0

Browse files
hmgaudeckerclaude
andauthored
Share JIT compilations across periods with same target config (#318)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 993d52d commit 114f9e0

10 files changed

Lines changed: 375 additions & 316 deletions

File tree

src/lcm/regime_building/Q_and_F.py

Lines changed: 98 additions & 123 deletions
Large diffs are not rendered by default.

src/lcm/regime_building/diagnostics.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from lcm.ages import AgeGrid
2121
from lcm.grids import Grid
2222
from lcm.interfaces import StateActionSpace
23-
from lcm.regime import Regime
24-
from lcm.regime_building.Q_and_F import get_compute_intermediates
23+
from lcm.regime_building.Q_and_F import get_complete_targets, get_compute_intermediates
2524
from lcm.regime_building.V import VInterpolationInfo
2625
from lcm.typing import (
2726
FunctionsMapping,
@@ -34,32 +33,29 @@
3433

3534
def _build_compute_intermediates_per_period(
3635
*,
37-
regime: Regime,
3836
flat_param_names: frozenset[str],
3937
regimes_to_active_periods: MappingProxyType[RegimeName, tuple[int, ...]],
4038
functions: FunctionsMapping,
4139
constraints: FunctionsMapping,
4240
transitions: TransitionFunctionsMapping,
4341
stochastic_transition_names: frozenset[str],
44-
compute_regime_transition_probs: RegimeTransitionFunction | None,
42+
compute_regime_transition_probs: RegimeTransitionFunction,
4543
regime_to_v_interpolation_info: MappingProxyType[RegimeName, VInterpolationInfo],
4644
state_action_space: StateActionSpace,
4745
grids: MappingProxyType[str, Grid],
4846
ages: AgeGrid,
4947
enable_jit: bool,
5048
) -> MappingProxyType[int, Callable]:
51-
"""Build diagnostic intermediate closures for each period.
49+
"""Build diagnostic intermediate closures for each period of a non-terminal regime.
5250
53-
The closures fuse a productmap over the full state-action space with
51+
Each closure fuses a productmap over the full state-action space with
5452
on-device reductions (matching the `max_Q_over_a` productmap pattern)
55-
and are JIT-compiled. Used in the error path when `validate_V` detects
56-
NaN; returns an empty mapping for terminal regimes.
53+
and is JIT-compiled. Periods sharing the same target configuration
54+
reuse a single scalar closure. The caller is responsible for handling
55+
terminal regimes. Used in the error path when `validate_V` detects NaN.
5756
5857
Args:
59-
regime: User regime; only the terminal flag is consulted.
60-
flat_param_names: Frozenset of flat parameter names for the regime;
61-
forwarded to `get_compute_intermediates` for the explicit
62-
signature productmap requires.
58+
flat_param_names: Frozenset of flat parameter names for the regime.
6359
regimes_to_active_periods: Immutable mapping of regime names to
6460
their active period tuples.
6561
functions: Immutable mapping of internal user functions.
@@ -69,7 +65,7 @@ def _build_compute_intermediates_per_period(
6965
stochastic_transition_names: Frozenset of stochastic transition
7066
function names.
7167
compute_regime_transition_probs: Regime transition probability
72-
function, or `None` for terminal regimes.
68+
function for the current regime.
7369
regime_to_v_interpolation_info: Mapping of regime names to
7470
V-interpolation info.
7571
state_action_space: State-action space used for productmap sizing.
@@ -79,37 +75,39 @@ def _build_compute_intermediates_per_period(
7975
enable_jit: Whether to JIT-compile the fused closure.
8076
8177
Returns:
82-
Immutable mapping of period index to fused closure; empty for
83-
terminal regimes.
78+
Immutable mapping of period index to fused closure.
8479
8580
"""
86-
if regime.terminal:
87-
return MappingProxyType({})
88-
89-
assert compute_regime_transition_probs is not None # noqa: S101
90-
9181
state_batch_sizes = {
9282
name: grid.batch_size
9383
for name, grid in grids.items()
9484
if name in state_action_space.state_names
9585
}
9686

87+
configs: dict[tuple[str, ...], list[int]] = {}
88+
for period in range(ages.n_periods):
89+
complete = get_complete_targets(
90+
period=period,
91+
transitions=transitions,
92+
regimes_to_active_periods=regimes_to_active_periods,
93+
stochastic_transition_names=stochastic_transition_names,
94+
regime_to_v_interpolation_info=regime_to_v_interpolation_info,
95+
)
96+
configs.setdefault(complete, []).append(period)
97+
9798
variable_names = (
9899
*state_action_space.state_names,
99100
*state_action_space.action_names,
100101
)
101-
102-
intermediates: dict[int, Callable] = {}
103-
for period, age in enumerate(ages.values):
102+
built: dict[tuple[str, ...], Callable] = {}
103+
for complete_targets in configs:
104104
scalar = get_compute_intermediates(
105-
age=age,
106-
period=period,
107105
flat_param_names=flat_param_names,
108106
functions=functions,
109107
constraints=constraints,
108+
complete_targets=complete_targets,
110109
transitions=transitions,
111110
stochastic_transition_names=stochastic_transition_names,
112-
regimes_to_active_periods=regimes_to_active_periods,
113111
compute_regime_transition_probs=compute_regime_transition_probs,
114112
regime_to_v_interpolation_info=regime_to_v_interpolation_info,
115113
)
@@ -119,13 +117,15 @@ def _build_compute_intermediates_per_period(
119117
state_names=state_action_space.state_names,
120118
state_batch_sizes=state_batch_sizes,
121119
)
122-
fused = _wrap_with_reduction(
123-
func=mapped,
124-
variable_names=variable_names,
125-
)
126-
intermediates[period] = jax.jit(fused) if enable_jit else fused
120+
fused = _wrap_with_reduction(func=mapped, variable_names=variable_names)
121+
built[complete_targets] = jax.jit(fused) if enable_jit else fused
122+
123+
result: dict[int, Callable] = {}
124+
for key, periods in configs.items():
125+
for period in periods:
126+
result[period] = built[key]
127127

128-
return MappingProxyType(intermediates)
128+
return MappingProxyType(result)
129129

130130

131131
def _wrap_with_reduction(

0 commit comments

Comments
 (0)