@@ -1614,6 +1614,7 @@ def test_build_outcome_mapping_qualified_func_name() -> None:
16141614 result = _build_outcome_mapping (
16151615 func_name = "next_health__working" ,
16161616 grids = grids ,
1617+ regimes = model .regimes ,
16171618 regime_names_to_ids = model .regime_names_to_ids ,
16181619 )
16191620 assert result .size == 2
@@ -1787,6 +1788,107 @@ class WrongPartner:
17871788 )
17881789
17891790
1791+ def test_convert_series_cross_grid_transition () -> None :
1792+ """Outcome axis must use the TARGET regime's grid, not the source's.
1793+
1794+ When a per-target MarkovTransition crosses grid sizes (e.g. 3-state
1795+ source → 2-state target), the converted array's last dimension must
1796+ match the target's grid size (2), not the source's (3).
1797+ """
1798+ from lcm import MarkovTransition # noqa: PLC0415
1799+ from lcm .typing import DiscreteState , FloatND , Period # noqa: PLC0415
1800+
1801+ @categorical (ordered = True )
1802+ class _HealthPre :
1803+ disabled : int
1804+ bad : int
1805+ good : int
1806+
1807+ @categorical (ordered = True )
1808+ class _HealthPost :
1809+ bad : int
1810+ good : int
1811+
1812+ @categorical (ordered = False )
1813+ class _RId :
1814+ pre65 : int
1815+ post65 : int
1816+
1817+ def _health_probs_same (
1818+ period : Period , health : DiscreteState , health_trans_probs : FloatND
1819+ ) -> FloatND :
1820+ return health_trans_probs [period , health ]
1821+
1822+ def _health_probs_cross (
1823+ period : Period , health : DiscreteState , health_trans_probs_cross : FloatND
1824+ ) -> FloatND :
1825+ return health_trans_probs_cross [period , health ]
1826+
1827+ pre65 = Regime (
1828+ states = {
1829+ "health" : DiscreteGrid (_HealthPre ),
1830+ "wealth" : LinSpacedGrid (start = 0 , stop = 10 , n_points = 5 ),
1831+ },
1832+ state_transitions = {
1833+ "health" : {
1834+ "pre65" : MarkovTransition (_health_probs_same ),
1835+ "post65" : MarkovTransition (_health_probs_cross ),
1836+ },
1837+ "wealth" : lambda wealth : wealth ,
1838+ },
1839+ functions = {"utility" : lambda health , wealth : wealth + health },
1840+ transition = lambda age : jnp .where (age >= 1 , _RId .post65 , _RId .pre65 ),
1841+ active = lambda age : age < 1 ,
1842+ )
1843+ post65 = Regime (
1844+ transition = None ,
1845+ states = {
1846+ "health" : DiscreteGrid (_HealthPost ),
1847+ "wealth" : LinSpacedGrid (start = 0 , stop = 10 , n_points = 5 ),
1848+ },
1849+ functions = {"utility" : lambda health , wealth : wealth + health },
1850+ )
1851+ model = Model (
1852+ regimes = {"pre65" : pre65 , "post65" : post65 },
1853+ ages = AgeGrid (start = 0 , stop = 1 , step = "Y" ),
1854+ regime_id_class = _RId ,
1855+ )
1856+
1857+ # Cross-grid transition probs: 3 source states → 2 target states
1858+ index_cross = pd .MultiIndex .from_tuples (
1859+ [
1860+ (0.0 , "disabled" , "bad" ),
1861+ (0.0 , "disabled" , "good" ),
1862+ (0.0 , "bad" , "bad" ),
1863+ (0.0 , "bad" , "good" ),
1864+ (0.0 , "good" , "bad" ),
1865+ (0.0 , "good" , "good" ),
1866+ ],
1867+ names = ["age" , "health" , "next_health" ],
1868+ )
1869+ sr_cross = pd .Series ([0.65 , 0.35 , 0.81 , 0.19 , 0.06 , 0.94 ], index = index_cross )
1870+
1871+ params = {
1872+ "pre65" : {
1873+ "to_post65_next_health" : {"health_trans_probs_cross" : sr_cross },
1874+ },
1875+ }
1876+ internal = broadcast_to_template (
1877+ params = params , template = model .get_params_template (), required = False
1878+ )
1879+ result = convert_series_in_params (
1880+ internal_params = internal ,
1881+ regimes = model .regimes ,
1882+ ages = model .ages ,
1883+ regime_names_to_ids = model .regime_names_to_ids ,
1884+ )
1885+
1886+ arr = result ["pre65" ]["to_post65_next_health__health_trans_probs_cross" ]
1887+ # Shape: (n_ages=2, n_source_health=3, n_target_health=2)
1888+ # n_ages=2 because AgeGrid has ages [0, 1]; missing age 1 is NaN-filled.
1889+ assert arr .shape == (2 , 3 , 2 ) # ty: ignore[unresolved-attribute]
1890+
1891+
17901892def test_resolve_categoricals_includes_derived_when_no_regime_name () -> None :
17911893 """derived_categoricals are included even when regime_name is None."""
17921894 from lcm .pandas_utils import _resolve_categoricals # noqa: PLC0415
0 commit comments