Skip to content

Commit 20669fd

Browse files
mj023pre-commit-ci[bot]hmgaudeckerclaude
authored
Add batched productmap (#280)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hans-Martin von Gaudecker <hmgaudecker@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 542e0d9 commit 20669fd

24 files changed

Lines changed: 1047 additions & 834 deletions

pixi.lock

Lines changed: 742 additions & 695 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/lcm/grids/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@
66
class Grid(ABC):
77
"""LCM Grid base class."""
88

9+
@property
10+
@abstractmethod
11+
def batch_size(self) -> int:
12+
"""Size of the batches looped over during the solution.
13+
14+
`ContinuousGrid` overrides this via its dataclass field.
15+
`DiscreteGrid` overrides this via its own property.
16+
17+
"""
18+
919
@abstractmethod
1020
def to_jax(self) -> Int1D | Float1D:
1121
"""Convert the grid to a Jax array."""

src/lcm/grids/continuous.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717

1818

19+
@dataclass(frozen=True, kw_only=True)
1920
class ContinuousGrid(Grid):
2021
"""Base class for grids representing continuous values with coordinate lookup.
2122
@@ -24,6 +25,9 @@ class ContinuousGrid(Grid):
2425
2526
"""
2627

28+
batch_size: int = 0
29+
"""Size of the batches that are looped over during the solution."""
30+
2731
@overload
2832
def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ...
2933
@overload

src/lcm/grids/discrete.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,26 @@
66
from lcm.utils.containers import get_field_names_and_values
77

88

9-
class _DiscreteGridBase(Grid):
10-
"""Base class for discrete grids: categories, codes, and JAX conversion."""
9+
class DiscreteGrid(Grid):
10+
"""A discrete grid defining the outcome space of a categorical variable.
11+
12+
Args:
13+
category_class: The category class representing the grid categories. Must
14+
be a dataclass with fields that have unique int values.
15+
16+
Raises:
17+
GridInitializationError: If the `category_class` is not a dataclass with int
18+
fields.
19+
20+
"""
1121

12-
def __init__(self, category_class: type) -> None:
22+
def __init__(self, category_class: type, batch_size: int = 0) -> None:
1323
_validate_discrete_grid(category_class)
1424
names_and_values = get_field_names_and_values(category_class)
1525
self.__categories = tuple(names_and_values.keys())
1626
self.__codes = tuple(names_and_values.values())
1727
self.__ordered: bool = getattr(category_class, "_ordered", False)
28+
self.__batch_size: int = batch_size
1829

1930
@property
2031
def categories(self) -> tuple[str, ...]:
@@ -31,23 +42,11 @@ def ordered(self) -> bool:
3142
"""Return whether the categories have a meaningful ordering."""
3243
return self.__ordered
3344

45+
@property
46+
def batch_size(self) -> int:
47+
"""Return batch size during solution."""
48+
return self.__batch_size
49+
3450
def to_jax(self) -> Int1D:
3551
"""Convert the grid to a Jax array."""
3652
return jnp.array(self.codes)
37-
38-
39-
class DiscreteGrid(_DiscreteGridBase):
40-
"""A discrete grid defining the outcome space of a categorical variable.
41-
42-
Args:
43-
category_class: The category class representing the grid categories. Must
44-
be a dataclass with fields that have unique int values.
45-
46-
Raises:
47-
GridInitializationError: If the `category_class` is not a dataclass with int
48-
fields.
49-
50-
"""
51-
52-
def __init__(self, category_class: type) -> None:
53-
super().__init__(category_class)

src/lcm/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def _check_leaf(value: object, path: str) -> None:
385385
type_name = type(value).__module__ + "." + type(value).__name__
386386
msg = (
387387
f"Parameter '{path}' is a {type_name} (shape {value.shape}). "
388-
f"Use jax.numpy.array() or pass a pd.Series with a named index."
388+
f"Use jnp.array() or pass a pd.Series with a named index."
389389
)
390390
raise InvalidParamsError(msg)
391391
type_name = type(value).__module__ + "." + type(value).__name__

src/lcm/pandas_utils.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def array_from_series(
357357
display_params = ["age" if p == "period" else p for p in indexing_params]
358358

359359
level_mappings = _build_level_mappings_for_param(
360-
indexing_params=display_params, all_grids=grids, ages=model.ages
360+
indexing_params=display_params, grids=grids, ages=model.ages
361361
)
362362

363363
# Append outcome axis for transition probability arrays (next_* functions
@@ -368,7 +368,7 @@ def array_from_series(
368368
]
369369
if next_levels:
370370
outcome_mapping = _build_outcome_mapping(
371-
func_name=func_name, all_grids=grids, model=model
371+
func_name=func_name, grids=grids, model=model
372372
)
373373
level_mappings = (*level_mappings, outcome_mapping)
374374

@@ -566,8 +566,8 @@ class _LevelMapping:
566566
size: int
567567
"""Number of positions along this axis."""
568568

569-
label_to_index: Callable[[object], int]
570-
"""Map a single label value to its integer index."""
569+
get_code_from_label: Callable[[str], int]
570+
"""Return the integer code for a label."""
571571

572572
valid_labels: tuple[str, ...] = ()
573573
"""Valid label names, for error messages. Empty for age levels."""
@@ -578,7 +578,7 @@ def _age_level_mapping(ages: AgeGrid) -> _LevelMapping:
578578
return _LevelMapping(
579579
name="age",
580580
size=ages.n_periods,
581-
label_to_index=ages.age_to_period, # ty: ignore[invalid-argument-type]
581+
get_code_from_label=ages.age_to_period, # ty: ignore[invalid-argument-type]
582582
)
583583

584584

@@ -597,23 +597,23 @@ def _grid_level_mapping(*, name: str, grid: DiscreteGrid) -> _LevelMapping:
597597
return _LevelMapping(
598598
name=name,
599599
size=len(grid.categories),
600-
label_to_index=label_to_code.__getitem__,
600+
get_code_from_label=label_to_code.__getitem__,
601601
valid_labels=grid.categories,
602602
)
603603

604604

605605
def _build_level_mappings_for_param(
606606
*,
607607
indexing_params: list[str],
608-
all_grids: dict[str, DiscreteGrid],
608+
grids: dict[str, DiscreteGrid],
609609
ages: AgeGrid,
610610
) -> tuple[_LevelMapping, ...]:
611611
"""Build level mappings for `array_from_series` from indexing params.
612612
613613
Args:
614614
indexing_params: Parameter names in output axis order, with
615615
`"period"` already replaced by `"age"`.
616-
all_grids: Categorical grid lookup.
616+
grids: Categorical grid lookup.
617617
ages: The model's `AgeGrid`.
618618
619619
Returns:
@@ -624,12 +624,12 @@ def _build_level_mappings_for_param(
624624
for param in indexing_params:
625625
if param == "age":
626626
mappings.append(_age_level_mapping(ages))
627-
elif param in all_grids:
628-
mappings.append(_grid_level_mapping(name=param, grid=all_grids[param]))
627+
elif param in grids:
628+
mappings.append(_grid_level_mapping(name=param, grid=grids[param]))
629629
else:
630630
msg = (
631631
f"Unrecognised indexing parameter '{param}'. Expected 'age' "
632-
f"or a discrete grid name ({sorted(all_grids)}). If "
632+
f"or a discrete grid name ({sorted(grids)}). If "
633633
f"'{param}' is a DAG function output, pass "
634634
f'derived_categoricals={{"{param}": DiscreteGrid(...)}} '
635635
f"to solve() / simulate()."
@@ -641,7 +641,7 @@ def _build_level_mappings_for_param(
641641
def _build_outcome_mapping(
642642
*,
643643
func_name: str,
644-
all_grids: dict[str, DiscreteGrid],
644+
grids: dict[str, DiscreteGrid],
645645
model: Model,
646646
) -> _LevelMapping:
647647
"""Build a `_LevelMapping` for the outcome axis of a `next_*` function.
@@ -651,7 +651,7 @@ def _build_outcome_mapping(
651651
652652
Args:
653653
func_name: Function name starting with `"next_"`.
654-
all_grids: Categorical grid lookup.
654+
grids: Categorical grid lookup.
655655
model: The LCM Model instance.
656656
657657
Returns:
@@ -663,13 +663,13 @@ def _build_outcome_mapping(
663663
return _LevelMapping(
664664
name="next_regime",
665665
size=len(regime_ids),
666-
label_to_index=regime_ids.__getitem__, # ty: ignore[invalid-argument-type]
666+
get_code_from_label=regime_ids.__getitem__,
667667
valid_labels=tuple(regime_ids),
668668
)
669669

670670
path = tree_path_from_qname(func_name)
671671
state_name = path[0].removeprefix("next_")
672-
return _grid_level_mapping(name=f"next_{state_name}", grid=all_grids[state_name])
672+
return _grid_level_mapping(name=f"next_{state_name}", grid=grids[state_name])
673673

674674

675675
def _scatter_series(
@@ -744,7 +744,7 @@ def _map_level(*, mapping: _LevelMapping, level_values: pd.Index) -> np.ndarray:
744744
raise ValueError(msg)
745745

746746
try:
747-
return np.array([mapping.label_to_index(v) for v in level_values])
747+
return np.array([mapping.get_code_from_label(v) for v in level_values])
748748
except ValueError:
749749
# Age levels: age_to_period raises ValueError with a good message
750750
raise

src/lcm/regime_building/Q_and_F.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,13 @@ def get_Q_and_F(
109109
- set(target_transitions)
110110
- {V_arr_name}
111111
)
112+
stochastic_variables = tuple(
113+
key for key in target_transitions if key in stochastic_transition_names
114+
)
112115
next_V[target_regime_name] = productmap(
113116
func=next_V_interpolator,
114-
variables=tuple(
115-
key for key in target_transitions if key in stochastic_transition_names
116-
),
117+
variables=stochastic_variables,
118+
batch_sizes=dict.fromkeys(stochastic_variables, 0),
117119
)
118120

119121
# ----------------------------------------------------------------------------------
@@ -342,7 +344,10 @@ def _outer(**kwargs: Float1D) -> FloatND:
342344
weights = jnp.array(list(kwargs.values()))
343345
return jnp.prod(weights)
344346

345-
return productmap(func=_outer, variables=tuple(arg_names))
347+
variables = tuple(arg_names)
348+
return productmap(
349+
func=_outer, variables=variables, batch_sizes=dict.fromkeys(variables, 0)
350+
)
346351

347352

348353
def _get_U_and_F(

src/lcm/regime_building/max_Q_over_a.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
def get_max_Q_over_a(
2222
*,
2323
Q_and_F: Callable[..., tuple[FloatND, BoolND]],
24+
batch_sizes: dict[str, int],
2425
action_names: tuple[str, ...],
2526
state_names: tuple[str, ...],
2627
) -> MaxQOverAFunction:
@@ -47,6 +48,8 @@ def get_max_Q_over_a(
4748
Q_and_F: A function that takes a state-action combination and returns the action
4849
value of that combination and whether the state-action combination is
4950
feasible.
51+
batch_sizes: Mapping of state variable names to batch sizes for the outer
52+
productmap over states. A batch size of 0 means no batching.
5053
action_names: Tuple of action variable names.
5154
state_names: Tuple of state names.
5255
@@ -60,9 +63,12 @@ def get_max_Q_over_a(
6063
Q_and_F=Q_and_F, action_names=action_names, state_names=state_names
6164
)
6265

66+
# Actions are the inner optimization axis — batching applies only to the
67+
# outer state loop.
6368
Q_and_F = productmap(
6469
func=Q_and_F,
6570
variables=action_names,
71+
batch_sizes=dict.fromkeys(action_names, 0),
6672
)
6773

6874
@with_signature(
@@ -80,7 +86,7 @@ def max_Q_over_a(
8086
)
8187
return Q_arr.max(where=F_arr, initial=-jnp.inf)
8288

83-
return productmap(func=max_Q_over_a, variables=state_names)
89+
return productmap(func=max_Q_over_a, variables=state_names, batch_sizes=batch_sizes)
8490

8591

8692
def get_argmax_and_max_Q_over_a(
@@ -130,6 +136,7 @@ def get_argmax_and_max_Q_over_a(
130136
Q_and_F = productmap(
131137
func=Q_and_F,
132138
variables=action_names,
139+
batch_sizes=dict.fromkeys(action_names, 0),
133140
)
134141

135142
@with_signature(

src/lcm/regime_building/processing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def process_regimes(
102102
)
103103
all_grids = MappingProxyType({n: get_grids(r) for n, r in regimes.items()})
104104

105+
_fail_if_action_has_batch_size(regimes)
106+
105107
regime_to_v_interpolation_info = MappingProxyType(
106108
{n: create_v_interpolation_info(r) for n, r in regimes.items()}
107109
)
@@ -243,6 +245,7 @@ def _build_solve_functions(
243245
max_Q_over_a = _build_max_Q_over_a_per_period(
244246
state_action_space=state_action_space,
245247
Q_and_F_functions=Q_and_F_functions,
248+
grids=all_grids[regime_name],
246249
enable_jit=enable_jit,
247250
)
248251

@@ -1256,13 +1259,19 @@ def _build_max_Q_over_a_per_period(
12561259
*,
12571260
state_action_space: StateActionSpace,
12581261
Q_and_F_functions: MappingProxyType[int, QAndFFunction],
1262+
grids: MappingProxyType[str, Grid],
12591263
enable_jit: bool,
12601264
) -> MappingProxyType[int, MaxQOverAFunction]:
12611265
"""Build max-Q-over-a closures for each period."""
12621266
result = {}
12631267
for period, Q_and_F in Q_and_F_functions.items():
12641268
func = get_max_Q_over_a(
12651269
Q_and_F=Q_and_F,
1270+
batch_sizes={
1271+
name: grid.batch_size
1272+
for name, grid in grids.items()
1273+
if name in state_action_space.state_names
1274+
},
12661275
action_names=state_action_space.action_names,
12671276
state_names=state_action_space.state_names,
12681277
)
@@ -1323,3 +1332,22 @@ def _build_next_state_vmapped(
13231332
)
13241333

13251334
return jax.jit(next_state_vmapped) if enable_jit else next_state_vmapped
1335+
1336+
1337+
def _fail_if_action_has_batch_size(regimes: Mapping[str, Regime]) -> None:
1338+
"""Raise if any action grid has a non-zero batch_size.
1339+
1340+
Batching applies only to the outer state loop during solving, not to the
1341+
inner action optimization. A non-zero batch_size on an action grid would be
1342+
silently ignored, so we reject it early.
1343+
1344+
"""
1345+
for regime_name, regime in regimes.items():
1346+
for action_name, grid in regime.actions.items():
1347+
if grid.batch_size != 0:
1348+
msg = (
1349+
f"batch_size > 0 is not supported on action grids. Only state "
1350+
f"grids can be batched. Found batch_size={grid.batch_size} on "
1351+
f"action '{action_name}' in regime '{regime_name}'."
1352+
)
1353+
raise ValueError(msg)

src/lcm/shocks/_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ class _ShockGrid(ContinuousGrid):
4545
@property
4646
def _param_field_names(self) -> tuple[str, ...]:
4747
"""Names of distribution-specific parameters."""
48-
return tuple(f.name for f in fields(self) if f.name != "n_points")
48+
return tuple(
49+
f.name for f in fields(self) if f.name not in {"n_points", "batch_size"}
50+
)
4951

5052
@property
5153
def params(self) -> MappingProxyType[str, float]:

0 commit comments

Comments
 (0)