Skip to content

Commit 1d8e3ac

Browse files
hmgaudeckerclaude
andauthored
Skip unreachable target regimes in continuation value loop (#316)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 49cb0e5 commit 1d8e3ac

5 files changed

Lines changed: 368 additions & 28 deletions

File tree

src/lcm/regime_building/Q_and_F.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,35 @@ def get_Q_and_F(
6666
joint_weights_from_marginals = {}
6767
next_V = {}
6868

69-
target_regime_names = tuple(transitions)
70-
active_regimes_next_period = tuple(
71-
target_regime_name
72-
for target_regime_name in target_regime_names
73-
if period + 1 in regimes_to_active_periods[target_regime_name]
69+
# Enumerate all active targets, not just those in transitions — targets
70+
# entirely absent from per-target dicts must also be detected.
71+
all_active_next_period = tuple(
72+
regime_name
73+
for regime_name in regime_to_v_interpolation_info
74+
if period + 1 in regimes_to_active_periods.get(regime_name, ())
7475
)
76+
77+
# Keep only targets whose stochastic state needs are all covered by
78+
# `transitions`. Targets with missing stochastic transitions are dropped
79+
# from the traced function; `validate_regime_transitions_all_periods`
80+
# (via `_validate_no_reachable_incomplete_targets` in
81+
# `lcm.utils.error_handling`) raises pre-solve if any such target has
82+
# non-zero transition probability.
83+
complete_targets: list[RegimeName] = []
84+
for regime_name in all_active_next_period:
85+
target_stochastic_needs = {
86+
f"next_{s}"
87+
for s in regime_to_v_interpolation_info[regime_name].state_names
88+
if f"next_{s}" in stochastic_transition_names
89+
}
90+
if regime_name in transitions and target_stochastic_needs.issubset(
91+
transitions[regime_name]
92+
):
93+
complete_targets.append(regime_name)
94+
7595
next_V_extra_param_names: dict[str, frozenset[str]] = {}
7696

77-
for target_regime_name in active_regimes_next_period:
97+
for target_regime_name in complete_targets:
7898
# Transitions from the current regime to the target regime
7999
target_transitions = transitions[target_regime_name]
80100

@@ -170,14 +190,16 @@ def Q_and_F(
170190
period=period,
171191
age=age,
172192
)
173-
# Filter to active regimes only — inactive regimes must have 0
174-
# probability (validated before solve).
193+
# `complete_targets` is resolved at trace time (it is a closure over
194+
# a Python list); incomplete-target validation happens outside JIT
195+
# in `_validate_no_reachable_incomplete_targets` so that the traced
196+
# graph contains no runtime error-raising callbacks.
175197
active_regime_probs = MappingProxyType(
176-
{r: regime_transition_probs[r] for r in active_regimes_next_period}
198+
{r: regime_transition_probs[r] for r in complete_targets}
177199
)
178200

179201
E_next_V = jnp.zeros_like(U_arr)
180-
for target_regime_name in active_regimes_next_period:
202+
for target_regime_name in complete_targets:
181203
next_states = state_transitions[target_regime_name](
182204
**states_actions_params,
183205
period=period,

src/lcm/regime_building/processing.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pandas as pd
1010
from dags import concatenate_functions, get_annotations, with_signature
1111
from dags.signature import rename_arguments
12-
from dags.tree import qname_from_tree_path, tree_path_from_qname
12+
from dags.tree import QNAME_DELIMITER, qname_from_tree_path, tree_path_from_qname
1313
from jax import Array
1414
from jax import numpy as jnp
1515

@@ -85,7 +85,7 @@ def process_regimes(
8585
The processed regimes.
8686
8787
"""
88-
states_per_regime: dict[str, set[str]] = {
88+
states_per_regime: dict[RegimeName, set[str]] = {
8989
name: set(regime.states.keys()) for name, regime in regimes.items()
9090
}
9191

@@ -174,7 +174,7 @@ def process_regimes(
174174
def _build_solve_functions(
175175
*,
176176
regime: Regime,
177-
regime_name: str,
177+
regime_name: RegimeName,
178178
nested_transitions: dict[str, dict[str, UserFunction] | UserFunction],
179179
all_grids: MappingProxyType[RegimeName, MappingProxyType[str, Grid]],
180180
regime_params_template: RegimeParamsTemplate,
@@ -262,7 +262,7 @@ def _build_solve_functions(
262262
def _build_simulate_functions(
263263
*,
264264
regime: Regime,
265-
regime_name: str,
265+
regime_name: RegimeName,
266266
nested_transitions: dict[str, dict[str, UserFunction] | UserFunction],
267267
all_grids: MappingProxyType[RegimeName, MappingProxyType[str, Grid]],
268268
regime_params_template: RegimeParamsTemplate,
@@ -508,13 +508,21 @@ def _process_regime_core(
508508
)
509509

510510
# Shock transitions bypass the stub pipeline entirely. Build weight and
511-
# next functions for ALL target regimes directly from each target's grid.
511+
# next functions for reachable target regimes from each target's grid.
512+
# Scope to targets already present in non-shock transitions to avoid
513+
# spurious entries for unreachable regimes.
512514
shock_names = variable_info.query("is_shock").index.tolist()
513-
target_shock_grids: dict[tuple[str, str], _ShockGrid] = { # ty: ignore[invalid-assignment]
514-
(regime, shock): grids[shock]
515+
reachable_targets = {
516+
tree_path_from_qname(k)[0]
517+
for k in flat_nested_transitions
518+
if QNAME_DELIMITER in k
519+
}
520+
target_shock_grids: dict[tuple[RegimeName, str], _ShockGrid] = {
521+
(regime, shock): grid
515522
for regime, grids in all_grids.items()
523+
if regime in reachable_targets
516524
for shock in shock_names
517-
if isinstance(grids.get(shock), _ShockGrid)
525+
if isinstance(grid := grids.get(shock), _ShockGrid)
518526
}
519527
functions |= {
520528
f"weight_{regime}__next_{shock}": _get_weights_func_for_shock(
@@ -567,7 +575,7 @@ def _process_regime_core(
567575
def _extract_transitions_from_regime(
568576
*,
569577
regime: Regime,
570-
states_per_regime: Mapping[str, set[str]],
578+
states_per_regime: Mapping[RegimeName, set[str]],
571579
) -> dict[str, dict[str, UserFunction] | UserFunction]:
572580
"""Extract transitions from `regime.state_transitions` and regime transition.
573581
@@ -600,7 +608,14 @@ def _extract_transitions_from_regime(
600608
{"next_regime": regime.transition},
601609
)
602610

603-
for target_regime_name, target_regime_state_names in states_per_regime.items():
611+
reachable_targets = _get_reachable_targets(
612+
per_target_transitions=per_target_transitions,
613+
simple_transitions=simple_transitions,
614+
states_per_regime=states_per_regime,
615+
)
616+
617+
for target_regime_name in reachable_targets:
618+
target_regime_state_names = states_per_regime[target_regime_name]
604619
target_dict: dict[str, UserFunction] = {}
605620
for state_name in target_regime_state_names:
606621
next_key = f"next_{state_name}"
@@ -616,6 +631,34 @@ def _extract_transitions_from_regime(
616631
return nested
617632

618633

634+
def _get_reachable_targets(
635+
*,
636+
per_target_transitions: dict[str, dict[str, UserFunction]],
637+
simple_transitions: dict[str, UserFunction],
638+
states_per_regime: Mapping[RegimeName, set[str]],
639+
) -> set[RegimeName]:
640+
"""Determine which target regimes need transition entries.
641+
642+
When per-target transitions exist, start from the explicitly named targets
643+
and add any target whose state needs are fully covered by simple
644+
(non-per-target) transitions. Without per-target transitions, all regimes
645+
are reachable.
646+
647+
"""
648+
if not per_target_transitions:
649+
return set(states_per_regime.keys())
650+
651+
targets: set[RegimeName] = set()
652+
for variants in per_target_transitions.values():
653+
targets |= variants.keys()
654+
for target_name, target_states in states_per_regime.items():
655+
if target_name not in targets:
656+
needed = {f"next_{s}" for s in target_states}
657+
if needed and needed.issubset(simple_transitions):
658+
targets.add(target_name)
659+
return targets
660+
661+
619662
def _classify_transitions(
620663
state_transitions: dict[str, UserFunction],
621664
) -> tuple[dict[str, UserFunction], dict[str, dict[str, UserFunction]]]:

src/lcm/utils/error_handling.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,17 @@ def validate_V(
6161
"reasons:\n"
6262
"- The user-defined functions returned invalid values.\n"
6363
"- It is impossible to reach an active regime, resulting in NaN regime\n"
64-
" transition probabilities."
64+
" transition probabilities.\n"
65+
"- A per-target state_transitions dict omits a reachable target\n"
66+
" (non-zero transition probability to an incomplete target)."
6567
)
6668

6769

6870
def validate_regime_transition_probs(
6971
*,
7072
regime_transition_probs: MappingProxyType[str, Array],
71-
active_regimes_next_period: tuple[str, ...],
72-
regime_name: str,
73+
active_regimes_next_period: tuple[RegimeName, ...],
74+
regime_name: RegimeName,
7375
age: ScalarInt | ScalarFloat,
7476
next_age: ScalarInt | ScalarFloat,
7577
state_action_values: MappingProxyType[str, Array] | None = None,
@@ -224,7 +226,7 @@ def validate_regime_transitions_all_periods(
224226
continue
225227

226228
_validate_regime_transition_single(
227-
internal_regime=internal_regime,
229+
internal_regimes=internal_regimes,
228230
regime_params=internal_params[name],
229231
active_regimes_next_period=active_regimes_next_period,
230232
regime_name=name,
@@ -235,10 +237,10 @@ def validate_regime_transitions_all_periods(
235237

236238
def _validate_regime_transition_single(
237239
*,
238-
internal_regime: InternalRegime,
240+
internal_regimes: MappingProxyType[RegimeName, InternalRegime],
239241
regime_params: FlatRegimeParams,
240-
active_regimes_next_period: tuple[str, ...],
241-
regime_name: str,
242+
active_regimes_next_period: tuple[RegimeName, ...],
243+
regime_name: RegimeName,
242244
period: int,
243245
ages: AgeGrid,
244246
) -> None:
@@ -248,6 +250,7 @@ def _validate_regime_transition_single(
248250
variables it accepts, using `jax.vmap` for vectorised evaluation.
249251
250252
"""
253+
internal_regime = internal_regimes[regime_name]
251254
# Non-None guaranteed: only called for non-terminal regimes
252255
regime_transition_func = (
253256
internal_regime.solve_functions.compute_regime_transition_probs
@@ -310,6 +313,64 @@ def _call(
310313
state_action_values=MappingProxyType(point),
311314
)
312315

316+
_validate_no_reachable_incomplete_targets(
317+
internal_regimes=internal_regimes,
318+
regime_transition_probs=regime_transition_probs,
319+
active_regimes_next_period=active_regimes_next_period,
320+
regime_name=regime_name,
321+
age=ages.values[period], # noqa: PD011
322+
)
323+
324+
325+
def _validate_no_reachable_incomplete_targets(
326+
*,
327+
internal_regimes: MappingProxyType[RegimeName, InternalRegime],
328+
regime_transition_probs: MappingProxyType[str, Array],
329+
active_regimes_next_period: tuple[RegimeName, ...],
330+
regime_name: RegimeName,
331+
age: ScalarInt | ScalarFloat,
332+
) -> None:
333+
"""Check that targets with incomplete stochastic transitions are unreachable.
334+
335+
A target is "incomplete" from the source regime if the source's
336+
`transitions[target_regime_name]` does not cover all of the target's
337+
stochastic state needs. Such targets must have zero transition
338+
probability, otherwise the continuation value cannot be computed. This
339+
includes self-transitions (regime reaches itself): omitting the
340+
self-entry in a per-target dict is a common user error.
341+
342+
"""
343+
solve_functions = internal_regimes[regime_name].solve_functions
344+
transitions = solve_functions.transitions
345+
stochastic_names = solve_functions.stochastic_transition_names
346+
347+
for target_regime_name in active_regimes_next_period:
348+
target_regime = internal_regimes[target_regime_name]
349+
target_state_names = tuple(target_regime.variable_info.query("is_state").index)
350+
needs = {
351+
f"next_{s}" for s in target_state_names if f"next_{s}" in stochastic_names
352+
}
353+
if not needs:
354+
continue
355+
if target_regime_name in transitions and needs.issubset(
356+
transitions[target_regime_name]
357+
):
358+
continue
359+
if not jnp.any(regime_transition_probs[target_regime_name] > 0):
360+
continue
361+
missing = sorted(needs - set(transitions.get(target_regime_name, {})))
362+
if target_regime_name not in transitions:
363+
missing = sorted(f"next_{s}" for s in target_state_names)
364+
raise InvalidRegimeTransitionProbabilitiesError(
365+
f"Regime '{regime_name}' at age {age} has positive transition "
366+
f"probability to '{target_regime_name}', but '{regime_name}' "
367+
f"does not provide state transition(s) for: {missing}. Extend "
368+
f"`state_transitions` in '{regime_name}' to cover "
369+
f"'{target_regime_name}' (via a per-target dict if the "
370+
f"transition differs by target), or ensure "
371+
f"'{target_regime_name}' is unreachable."
372+
)
373+
313374

314375
def _get_func_indexing_params(
315376
*,

0 commit comments

Comments
 (0)