Skip to content

Commit 49cb0e5

Browse files
hmgaudeckerclaude
andauthored
Fix outcome axis for cross-grid stochastic transitions (#321)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 303df59 commit 49cb0e5

2 files changed

Lines changed: 118 additions & 0 deletions

File tree

src/lcm/pandas_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def array_from_series(
379379
outcome_mapping = _build_outcome_mapping(
380380
func_name=func_name,
381381
grids=grids,
382+
regimes=regimes,
382383
regime_names_to_ids=regime_names_to_ids,
383384
)
384385
level_mappings = (*level_mappings, outcome_mapping)
@@ -607,16 +608,20 @@ def _build_outcome_mapping(
607608
*,
608609
func_name: str,
609610
grids: dict[str, DiscreteGrid],
611+
regimes: Mapping[str, Regime],
610612
regime_names_to_ids: RegimeNamesToIds,
611613
) -> _LevelMapping:
612614
"""Build a `_LevelMapping` for the outcome axis of a `next_*` function.
613615
614616
For state transitions (e.g. `"next_partner"`), look up the state grid.
617+
For per-target transitions (e.g. `"next_health__post65"`), use the target
618+
regime's grid for the outcome axis.
615619
For regime transitions (`"next_regime"`), use `regime_names_to_ids`.
616620
617621
Args:
618622
func_name: Function name starting with `"next_"`.
619623
grids: Categorical grid lookup.
624+
regimes: Mapping of regime names to user Regime instances.
620625
regime_names_to_ids: Immutable mapping from regime names to integer
621626
indices.
622627
@@ -634,6 +639,17 @@ def _build_outcome_mapping(
634639

635640
path = tree_path_from_qname(func_name)
636641
state_name = path[0].removeprefix("next_")
642+
643+
# Per-target transitions (e.g. "next_health__post65") must use the TARGET
644+
# regime's grid for the outcome axis, not the source regime's grid.
645+
if len(path) > 1:
646+
target_regime_name = path[1]
647+
target_regime = regimes.get(target_regime_name)
648+
if target_regime is not None and state_name in target_regime.states:
649+
target_grid = target_regime.states[state_name]
650+
if isinstance(target_grid, DiscreteGrid):
651+
return _grid_level_mapping(name=f"next_{state_name}", grid=target_grid)
652+
637653
return _grid_level_mapping(name=f"next_{state_name}", grid=grids[state_name])
638654

639655

tests/test_pandas_utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,6 +1614,7 @@ def test_build_outcome_mapping_qualified_func_name() -> None:
16141614
result = _build_outcome_mapping(
16151615
func_name="next_health__working",
16161616
grids=grids,
1617+
regimes=model.regimes,
16171618
regime_names_to_ids=model.regime_names_to_ids,
16181619
)
16191620
assert result.size == 2
@@ -1787,6 +1788,107 @@ class WrongPartner:
17871788
)
17881789

17891790

1791+
def test_convert_series_cross_grid_transition() -> None:
1792+
"""Outcome axis must use the TARGET regime's grid, not the source's.
1793+
1794+
When a per-target MarkovTransition crosses grid sizes (e.g. 3-state
1795+
source → 2-state target), the converted array's last dimension must
1796+
match the target's grid size (2), not the source's (3).
1797+
"""
1798+
from lcm import MarkovTransition # noqa: PLC0415
1799+
from lcm.typing import DiscreteState, FloatND, Period # noqa: PLC0415
1800+
1801+
@categorical(ordered=True)
1802+
class _HealthPre:
1803+
disabled: int
1804+
bad: int
1805+
good: int
1806+
1807+
@categorical(ordered=True)
1808+
class _HealthPost:
1809+
bad: int
1810+
good: int
1811+
1812+
@categorical(ordered=False)
1813+
class _RId:
1814+
pre65: int
1815+
post65: int
1816+
1817+
def _health_probs_same(
1818+
period: Period, health: DiscreteState, health_trans_probs: FloatND
1819+
) -> FloatND:
1820+
return health_trans_probs[period, health]
1821+
1822+
def _health_probs_cross(
1823+
period: Period, health: DiscreteState, health_trans_probs_cross: FloatND
1824+
) -> FloatND:
1825+
return health_trans_probs_cross[period, health]
1826+
1827+
pre65 = Regime(
1828+
states={
1829+
"health": DiscreteGrid(_HealthPre),
1830+
"wealth": LinSpacedGrid(start=0, stop=10, n_points=5),
1831+
},
1832+
state_transitions={
1833+
"health": {
1834+
"pre65": MarkovTransition(_health_probs_same),
1835+
"post65": MarkovTransition(_health_probs_cross),
1836+
},
1837+
"wealth": lambda wealth: wealth,
1838+
},
1839+
functions={"utility": lambda health, wealth: wealth + health},
1840+
transition=lambda age: jnp.where(age >= 1, _RId.post65, _RId.pre65),
1841+
active=lambda age: age < 1,
1842+
)
1843+
post65 = Regime(
1844+
transition=None,
1845+
states={
1846+
"health": DiscreteGrid(_HealthPost),
1847+
"wealth": LinSpacedGrid(start=0, stop=10, n_points=5),
1848+
},
1849+
functions={"utility": lambda health, wealth: wealth + health},
1850+
)
1851+
model = Model(
1852+
regimes={"pre65": pre65, "post65": post65},
1853+
ages=AgeGrid(start=0, stop=1, step="Y"),
1854+
regime_id_class=_RId,
1855+
)
1856+
1857+
# Cross-grid transition probs: 3 source states → 2 target states
1858+
index_cross = pd.MultiIndex.from_tuples(
1859+
[
1860+
(0.0, "disabled", "bad"),
1861+
(0.0, "disabled", "good"),
1862+
(0.0, "bad", "bad"),
1863+
(0.0, "bad", "good"),
1864+
(0.0, "good", "bad"),
1865+
(0.0, "good", "good"),
1866+
],
1867+
names=["age", "health", "next_health"],
1868+
)
1869+
sr_cross = pd.Series([0.65, 0.35, 0.81, 0.19, 0.06, 0.94], index=index_cross)
1870+
1871+
params = {
1872+
"pre65": {
1873+
"to_post65_next_health": {"health_trans_probs_cross": sr_cross},
1874+
},
1875+
}
1876+
internal = broadcast_to_template(
1877+
params=params, template=model.get_params_template(), required=False
1878+
)
1879+
result = convert_series_in_params(
1880+
internal_params=internal,
1881+
regimes=model.regimes,
1882+
ages=model.ages,
1883+
regime_names_to_ids=model.regime_names_to_ids,
1884+
)
1885+
1886+
arr = result["pre65"]["to_post65_next_health__health_trans_probs_cross"]
1887+
# Shape: (n_ages=2, n_source_health=3, n_target_health=2)
1888+
# n_ages=2 because AgeGrid has ages [0, 1]; missing age 1 is NaN-filled.
1889+
assert arr.shape == (2, 3, 2) # ty: ignore[unresolved-attribute]
1890+
1891+
17901892
def test_resolve_categoricals_includes_derived_when_no_regime_name() -> None:
17911893
"""derived_categoricals are included even when regime_name is None."""
17921894
from lcm.pandas_utils import _resolve_categoricals # noqa: PLC0415

0 commit comments

Comments
 (0)