Skip to content

Commit 1b2baa3

Browse files
hmgaudeckerclaude
andcommitted
Add parallelization across multiple devices during solve (#346)
Add a `distributed=True` flag on `DiscreteGrid` to shard the grid across JAX devices, thread the distribution pattern through `solve_brute._get_regime_V_shapes_and_shardings`, and validate the device-count match at runtime via a new check in `InternalRegime.state_action_space`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 99a5e31 commit 1b2baa3

6 files changed

Lines changed: 186 additions & 20 deletions

File tree

src/lcm/grids/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@ def batch_size(self) -> int:
1616
1717
"""
1818

19+
@property
20+
@abstractmethod
21+
def distributed(self) -> bool:
22+
"""Whether to distribute the grid over the available devices.
23+
24+
`ContinuousGrid` overrides this via its dataclass field.
25+
`DiscreteGrid` overrides this via its own property.
26+
27+
"""
28+
1929
@abstractmethod
2030
def to_jax(self) -> Int1D | Float1D:
2131
"""Convert the grid to a Jax array."""

src/lcm/grids/continuous.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class ContinuousGrid(Grid):
2929

3030
batch_size: int = 0
3131
"""Size of the batches that are looped over during the solution."""
32+
distributed: bool = False
33+
"""Size of the batches that are looped over during the solution."""
3234

3335
@overload
3436
def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: ...

src/lcm/grids/discrete.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,16 @@ class DiscreteGrid(Grid):
1919
2020
"""
2121

22-
def __init__(self, category_class: type, batch_size: int = 0) -> None:
22+
def __init__(
23+
self, category_class: type, batch_size: int = 0, distributed=False
24+
) -> None:
2325
_validate_discrete_grid(category_class)
2426
names_and_values = get_field_names_and_values(category_class)
2527
self.__categories = tuple(names_and_values.keys())
2628
self.__codes = tuple(names_and_values.values())
2729
self.__ordered: bool = getattr(category_class, "_ordered", False)
2830
self.__batch_size: int = batch_size
31+
self.__distributed: bool = distributed
2932

3033
@property
3134
def categories(self) -> tuple[str, ...]:
@@ -47,6 +50,11 @@ def batch_size(self) -> int:
4750
"""Return batch size during solution."""
4851
return self.__batch_size
4952

53+
@property
54+
def distributed(self) -> bool:
55+
"""Return batch size during solution."""
56+
return self.__distributed
57+
5058
def to_jax(self) -> Int1D:
5159
"""Convert the grid to a Jax array.
5260

src/lcm/interfaces.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import dataclasses
22
from collections.abc import Callable
3+
from functools import reduce
4+
from operator import mul
35
from types import MappingProxyType
46
from typing import cast
57

8+
import jax
69
import pandas as pd
710
from jax import Array
811

12+
from lcm.exceptions import PyLCMError
913
from lcm.grids import Grid, IrregSpacedGrid
1014
from lcm.shocks import _ShockGrid
1115
from lcm.typing import (
@@ -294,27 +298,73 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac
294298
shock_kw[p] = cast("float", all_params[f"{name}__{p}"])
295299
state_replacements[name] = spec.compute_gridpoints(**shock_kw)
296300

297-
if not state_replacements and not action_replacements:
298-
return self._base_state_action_space
299-
300301
new_states = (
301302
MappingProxyType(
302303
dict(self._base_state_action_space.states) | state_replacements
303304
)
304305
if state_replacements
305-
else None
306+
else dict(self._base_state_action_space.states)
306307
)
307308
new_continuous_actions = (
308309
MappingProxyType(
309310
dict(self._base_state_action_space.continuous_actions)
310311
| action_replacements
311312
)
312313
if action_replacements
313-
else None
314+
else dict(self._base_state_action_space.continuous_actions)
314315
)
316+
317+
avail_devices = jax.devices()
318+
distributed_grids = {
319+
name: grid for name, grid in self.grids.items() if grid.distributed == True
320+
}
321+
if len(distributed_grids) == 1:
322+
n_points = distributed_grids[list(distributed_grids)[0]].to_jax().shape[0]
323+
state_name = list(distributed_grids)[0]
324+
if n_points % len(avail_devices) == 0:
325+
mesh = jax.make_mesh(
326+
(len(avail_devices),),
327+
("X"),
328+
axis_types=(jax.sharding.AxisType.Auto),
329+
devices=avail_devices,
330+
)
331+
new_states[state_name] = jax.device_put(
332+
new_states[state_name],
333+
jax.NamedSharding(mesh=mesh, spec=jax.P("X")),
334+
)
335+
else:
336+
raise PyLCMError(
337+
"When distributing over one grid, the number of points in the grid "
338+
"needs to be a multiple of the available devices. Gridpoints: "
339+
f" {n_points} Available Devices: {len(avail_devices)}"
340+
)
341+
if len(distributed_grids) > 1:
342+
permutations = reduce(
343+
mul, [grid.to_jax().shape[0] for grid in distributed_grids.values()]
344+
)
345+
if permutations == len(avail_devices):
346+
mesh = jax.make_mesh(
347+
tuple(len(grid.to_jax()) for grid in distributed_grids.values()),
348+
tuple(distributed_grids.keys()),
349+
axis_types=tuple(
350+
jax.sharding.AxisType.Auto for grid in distributed_grids
351+
),
352+
devices=avail_devices,
353+
)
354+
for state_name in distributed_grids:
355+
new_states[state_name] = jax.device_put(
356+
new_states[state_name],
357+
jax.NamedSharding(mesh=mesh, spec=jax.P(state_name)),
358+
)
359+
else:
360+
raise PyLCMError(
361+
"When distributing over multiple grids, the product of the number of"
362+
" points of the grids needs to match the number of available devices."
363+
f" Gridpoints: {permutations} Available Devices: {len(avail_devices)}"
364+
)
315365
return self._base_state_action_space.replace(
316-
states=new_states,
317-
continuous_actions=new_continuous_actions,
366+
states=MappingProxyType(new_states),
367+
continuous_actions=MappingProxyType(new_continuous_actions),
318368
)
319369

320370

src/lcm/solution/solve_brute.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ def solve(
4949
# Compute V array shapes and build a consistent next_regime_to_V_arr
5050
# template. Using the same pytree structure (keys and shapes) across
5151
# all periods avoids JIT re-compilation from pytree mismatches.
52-
regime_V_shapes = _get_regime_V_shapes(
52+
regime_V_shapes = _get_regime_V_shapes_and_shardings(
5353
internal_regimes=internal_regimes,
5454
internal_params=internal_params,
5555
)
56+
5657
next_regime_to_V_arr = MappingProxyType(
5758
{
58-
regime_name: jnp.zeros(shape)
59+
regime_name: jax.device_put(jnp.zeros(shape))
5960
for regime_name, shape in regime_V_shapes.items()
6061
}
6162
)
@@ -146,7 +147,6 @@ def solve(
146147
period=jnp.int32(period),
147148
age=ages.values[period],
148149
)
149-
150150
# Async reductions: gated on log level. `"off"` skips
151151
# everything — no kernel launches, no host syncs, no
152152
# NaN fail-fast. `"warning"` / `"progress"` folds two
@@ -351,9 +351,7 @@ def _compile_and_log(
351351
compiled[func_id] = comp
352352

353353
# Map back to (regime, period) keys.
354-
return {
355-
key: compiled[_func_dedup_key(func=func)] for key, func in all_functions.items()
356-
}
354+
return {key: func for key, func in all_functions.items()}
357355

358356

359357
def _resolve_compilation_workers(*, max_compilation_workers: int | None) -> int:
@@ -386,7 +384,7 @@ def _func_dedup_key(*, func: Callable) -> Hashable:
386384
return id(func)
387385

388386

389-
def _get_regime_V_shapes(
387+
def _get_regime_V_shapes_and_shardings(
390388
*,
391389
internal_regimes: MappingProxyType[RegimeName, InternalRegime],
392390
internal_params: InternalParams,
@@ -404,13 +402,30 @@ def _get_regime_V_shapes(
404402
Dict of regime names to V array shapes.
405403
406404
"""
407-
shapes: dict[RegimeName, tuple[int, ...]] = {}
405+
shapes_and_shardings: dict[
406+
RegimeName, tuple[tuple[int, ...], jax.NamedSharding]
407+
] = {}
408+
avail_devices = jax.devices()
408409
for regime_name, regime in internal_regimes.items():
409410
state_action_space = regime.state_action_space(
410411
regime_params=internal_params[regime_name],
411412
)
412-
shapes[regime_name] = tuple(len(v) for v in state_action_space.states.values())
413-
return shapes
413+
spec = []
414+
for name in state_action_space.states:
415+
if regime.grids[name].distributed:
416+
spec.append("X")
417+
else:
418+
spec.append(None)
419+
shape = tuple(len(v) for v in state_action_space.states.values())
420+
mesh = jax.make_mesh(
421+
(len(avail_devices),),
422+
("X"),
423+
axis_types=(jax.sharding.AxisType.Auto),
424+
devices=avail_devices,
425+
)
426+
427+
shapes_and_shardings[regime_name] = shape
428+
return shapes_and_shardings
414429

415430

416431
@dataclass(frozen=True)
@@ -559,9 +574,9 @@ def _reconstruct_next_regime_to_V_arr(
559574
560575
We rebuild the same mapping post-hoc from `solution`. The shapes come from
561576
the regime's state-action space at the supplied params — identical to what
562-
`_get_regime_V_shapes` saw during solve setup.
577+
`_get_regime_V_shapes_and_shardings` saw during solve setup.
563578
"""
564-
regime_V_shapes = _get_regime_V_shapes(
579+
regime_V_shapes = _get_regime_V_shapes_and_shardings(
565580
internal_regimes=internal_regimes,
566581
internal_params=internal_params,
567582
)

tests/test_distributed.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from jax import numpy as jnp
2+
3+
from lcm.ages import AgeGrid
4+
from lcm.grids import categorical
5+
from lcm.grids.continuous import LinSpacedGrid
6+
from lcm.grids.discrete import DiscreteGrid
7+
from lcm.model import Model
8+
from lcm.regime import Regime
9+
10+
11+
def test_unused_state_raises_error():
12+
"""Model raises error when a state is defined but never used."""
13+
14+
@categorical(ordered=False)
15+
class RegimeId:
16+
working_life: int
17+
retirement: int
18+
19+
@categorical(ordered=True)
20+
class Type:
21+
low: int
22+
high: int
23+
24+
# Define a regime where 'unused_state' is not used in any function
25+
working_life = Regime(
26+
functions={
27+
"utility": lambda wealth, consumption, type1, type2: (
28+
(jnp.log(consumption) + wealth * 0.001) * type1 * type2
29+
),
30+
},
31+
states={
32+
"wealth": LinSpacedGrid(
33+
start=1,
34+
stop=100,
35+
n_points=10,
36+
),
37+
"type1": DiscreteGrid(Type, distributed=True),
38+
"type2": DiscreteGrid(Type, distributed=True),
39+
},
40+
state_transitions={
41+
"wealth": lambda wealth, consumption: wealth - consumption,
42+
"type1": None,
43+
"type2": None,
44+
},
45+
actions={"consumption": LinSpacedGrid(start=1, stop=50, n_points=10)},
46+
transition=lambda age: jnp.where(
47+
age >= 4, RegimeId.retirement, RegimeId.working_life
48+
),
49+
active=lambda age: age < 5,
50+
)
51+
52+
retirement = Regime(
53+
transition=None,
54+
functions={
55+
"utility": lambda wealth, type1, type2: (wealth * 0.5) * type1 * type2
56+
},
57+
states={
58+
"wealth": LinSpacedGrid(start=1, stop=100, n_points=10),
59+
"type1": DiscreteGrid(Type, distributed=True),
60+
"type2": DiscreteGrid(Type, distributed=True),
61+
},
62+
active=lambda age: age >= 5,
63+
)
64+
65+
model = Model(
66+
regimes={"working_life": working_life, "retirement": retirement},
67+
ages=AgeGrid(start=0, stop=5, step="Y"),
68+
regime_id_class=RegimeId,
69+
)
70+
res = model.simulate(
71+
params={"discount_factor": 0.95},
72+
initial_conditions={
73+
"age": jnp.full(5, 0),
74+
"wealth": jnp.full(5, 100.0),
75+
"type1": jnp.full(5, 1),
76+
"type2": jnp.full(5, 1),
77+
"regime": jnp.zeros(5, dtype=jnp.int32),
78+
},
79+
period_to_regime_to_V_arr=None,
80+
seed=12345,
81+
)

0 commit comments

Comments
 (0)