@@ -357,7 +357,7 @@ def array_from_series(
357357 display_params = ["age" if p == "period" else p for p in indexing_params ]
358358
359359 level_mappings = _build_level_mappings_for_param (
360- indexing_params = display_params , all_grids = grids , ages = model .ages
360+ indexing_params = display_params , grids = grids , ages = model .ages
361361 )
362362
363363 # Append outcome axis for transition probability arrays (next_* functions
@@ -368,7 +368,7 @@ def array_from_series(
368368 ]
369369 if next_levels :
370370 outcome_mapping = _build_outcome_mapping (
371- func_name = func_name , all_grids = grids , model = model
371+ func_name = func_name , grids = grids , model = model
372372 )
373373 level_mappings = (* level_mappings , outcome_mapping )
374374
@@ -566,8 +566,8 @@ class _LevelMapping:
566566 size : int
567567 """Number of positions along this axis."""
568568
569- label_to_index : Callable [[object ], int ]
570- """Map a single label value to its integer index ."""
569+ get_code_from_label : Callable [[str ], int ]
570+ """Return the integer code for a label ."""
571571
572572 valid_labels : tuple [str , ...] = ()
573573 """Valid label names, for error messages. Empty for age levels."""
@@ -578,7 +578,7 @@ def _age_level_mapping(ages: AgeGrid) -> _LevelMapping:
578578 return _LevelMapping (
579579 name = "age" ,
580580 size = ages .n_periods ,
581- label_to_index = ages .age_to_period , # ty: ignore[invalid-argument-type]
581+ get_code_from_label = ages .age_to_period , # ty: ignore[invalid-argument-type]
582582 )
583583
584584
@@ -597,23 +597,23 @@ def _grid_level_mapping(*, name: str, grid: DiscreteGrid) -> _LevelMapping:
597597 return _LevelMapping (
598598 name = name ,
599599 size = len (grid .categories ),
600- label_to_index = label_to_code .__getitem__ ,
600+ get_code_from_label = label_to_code .__getitem__ ,
601601 valid_labels = grid .categories ,
602602 )
603603
604604
605605def _build_level_mappings_for_param (
606606 * ,
607607 indexing_params : list [str ],
608- all_grids : dict [str , DiscreteGrid ],
608+ grids : dict [str , DiscreteGrid ],
609609 ages : AgeGrid ,
610610) -> tuple [_LevelMapping , ...]:
611611 """Build level mappings for `array_from_series` from indexing params.
612612
613613 Args:
614614 indexing_params: Parameter names in output axis order, with
615615 `"period"` already replaced by `"age"`.
616- all_grids : Categorical grid lookup.
616+ grids : Categorical grid lookup.
617617 ages: The model's `AgeGrid`.
618618
619619 Returns:
@@ -624,12 +624,12 @@ def _build_level_mappings_for_param(
624624 for param in indexing_params :
625625 if param == "age" :
626626 mappings .append (_age_level_mapping (ages ))
627- elif param in all_grids :
628- mappings .append (_grid_level_mapping (name = param , grid = all_grids [param ]))
627+ elif param in grids :
628+ mappings .append (_grid_level_mapping (name = param , grid = grids [param ]))
629629 else :
630630 msg = (
631631 f"Unrecognised indexing parameter '{ param } '. Expected 'age' "
632- f"or a discrete grid name ({ sorted (all_grids )} ). If "
632+ f"or a discrete grid name ({ sorted (grids )} ). If "
633633 f"'{ param } ' is a DAG function output, pass "
634634 f'derived_categoricals={{"{ param } ": DiscreteGrid(...)}} '
635635 f"to solve() / simulate()."
@@ -641,7 +641,7 @@ def _build_level_mappings_for_param(
641641def _build_outcome_mapping (
642642 * ,
643643 func_name : str ,
644- all_grids : dict [str , DiscreteGrid ],
644+ grids : dict [str , DiscreteGrid ],
645645 model : Model ,
646646) -> _LevelMapping :
647647 """Build a `_LevelMapping` for the outcome axis of a `next_*` function.
@@ -651,7 +651,7 @@ def _build_outcome_mapping(
651651
652652 Args:
653653 func_name: Function name starting with `"next_"`.
654- all_grids : Categorical grid lookup.
654+ grids : Categorical grid lookup.
655655 model: The LCM Model instance.
656656
657657 Returns:
@@ -663,13 +663,13 @@ def _build_outcome_mapping(
663663 return _LevelMapping (
664664 name = "next_regime" ,
665665 size = len (regime_ids ),
666- label_to_index = regime_ids .__getitem__ , # ty: ignore[invalid-argument-type]
666+ get_code_from_label = regime_ids .__getitem__ ,
667667 valid_labels = tuple (regime_ids ),
668668 )
669669
670670 path = tree_path_from_qname (func_name )
671671 state_name = path [0 ].removeprefix ("next_" )
672- return _grid_level_mapping (name = f"next_{ state_name } " , grid = all_grids [state_name ])
672+ return _grid_level_mapping (name = f"next_{ state_name } " , grid = grids [state_name ])
673673
674674
675675def _scatter_series (
@@ -744,7 +744,7 @@ def _map_level(*, mapping: _LevelMapping, level_values: pd.Index) -> np.ndarray:
744744 raise ValueError (msg )
745745
746746 try :
747- return np .array ([mapping .label_to_index (v ) for v in level_values ])
747+ return np .array ([mapping .get_code_from_label (v ) for v in level_values ])
748748 except ValueError :
749749 # Age levels: age_to_period raises ValueError with a good message
750750 raise
0 commit comments