Skip to content

Commit e3bd7e4

Browse files
committed
Add second distribution pattern
1 parent 2cab0af commit e3bd7e4

4 files changed

Lines changed: 67 additions & 64 deletions

File tree

src/lcm/grids/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def batch_size(self) -> int:
1616
1717
"""
1818

19-
2019
@property
2120
@abstractmethod
2221
def distributed(self) -> bool:

src/lcm/grids/discrete.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ class DiscreteGrid(Grid):
1919
2020
"""
2121

22-
def __init__(self, category_class: type, batch_size: int = 0, distributed = False) -> None:
22+
def __init__(
23+
self, category_class: type, batch_size: int = 0, distributed=False
24+
) -> None:
2325
_validate_discrete_grid(category_class)
2426
names_and_values = get_field_names_and_values(category_class)
2527
self.__categories = tuple(names_and_values.keys())
@@ -47,7 +49,7 @@ def ordered(self) -> bool:
4749
def batch_size(self) -> int:
4850
"""Return batch size during solution."""
4951
return self.__batch_size
50-
52+
5153
@property
5254
def distributed(self) -> bool:
5355
"""Return batch size during solution."""

src/lcm/interfaces.py

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import dataclasses
22
from collections.abc import Callable
3+
from functools import reduce
4+
from operator import mul
35
from types import MappingProxyType
46
from typing import cast
57

6-
from functools import reduce
7-
from operator import mul
8+
import jax
89
import pandas as pd
910
from jax import Array
10-
import jax
11+
1112
from lcm.exceptions import PyLCMError
1213
from lcm.grids import Grid, IrregSpacedGrid
1314
from lcm.shocks import _ShockGrid
@@ -314,65 +315,61 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac
314315
| action_replacements
315316
)
316317
if action_replacements
317-
else None
318+
else dict(self._base_state_action_space.continuous_actions)
318319
)
319-
320+
320321
avail_devices = jax.devices()
321-
distributed_grids = {name:grid for name,grid in self.grids.items() if grid.distributed == True}
322-
print(distributed_grids)
322+
distributed_grids = {
323+
name: grid for name, grid in self.grids.items() if grid.distributed == True
324+
}
323325
if len(distributed_grids) == 1:
324326
n_points = distributed_grids[list(distributed_grids)[0]].to_jax().shape[0]
325327
state_name = list(distributed_grids)[0]
326328
if n_points % len(avail_devices) == 0:
327-
mesh = jax.make_mesh((len(avail_devices),), ('X'), axis_types=(jax.sharding.AxisType.Auto),devices=avail_devices)
328-
new_states[state_name] = jax.device_put(new_states[state_name], jax.NamedSharding(mesh=mesh, spec=jax.P('X',)))
329+
mesh = jax.make_mesh(
330+
(len(avail_devices),),
331+
("X"),
332+
axis_types=(jax.sharding.AxisType.Auto),
333+
devices=avail_devices,
334+
)
335+
new_states[state_name] = jax.device_put(
336+
new_states[state_name],
337+
jax.NamedSharding(mesh=mesh, spec=jax.P("X")),
338+
)
329339
else:
330340
raise PyLCMError(
331-
"When distributing over one grid, the number of points in the grid "
332-
"needs to be a multiple of the available devices. Gridpoints: "
333-
f" {n_points} Available Devices: {len(avail_devices)}"
334-
)
341+
"When distributing over one grid, the number of points in the grid "
342+
"needs to be a multiple of the available devices. Gridpoints: "
343+
f" {n_points} Available Devices: {len(avail_devices)}"
344+
)
335345
if len(distributed_grids) > 1:
336-
permutations = reduce(mul, [grid.to_jax().shape[0] for grid in distributed_grids.values()])
337-
print(permutations)
346+
permutations = reduce(
347+
mul, [grid.to_jax().shape[0] for grid in distributed_grids.values()]
348+
)
338349
if permutations == len(avail_devices):
339-
device_orders = _partitioning_algo(list(distributed_grids.values()), avail_devices)
340-
print(device_orders)
341-
for i, (state_name, grid) in enumerate(distributed_grids.items()):
342-
mesh = jax.make_mesh((grid.to_jax().shape[0],), ('X'), devices=device_orders[i])
343-
new_states[state_name] = jax.device_put(new_states[state_name],jax.NamedSharding(mesh=mesh, spec=jax.P('X',)))
350+
mesh = jax.make_mesh(
351+
tuple(len(grid.to_jax()) for grid in distributed_grids.values()),
352+
tuple(distributed_grids.keys()),
353+
axis_types=tuple(
354+
jax.sharding.AxisType.Auto for grid in distributed_grids
355+
),
356+
devices=avail_devices,
357+
)
358+
for state_name in distributed_grids:
359+
new_states[state_name] = jax.device_put(
360+
new_states[state_name],
361+
jax.NamedSharding(mesh=mesh, spec=jax.P(state_name)),
362+
)
344363
else:
345364
raise PyLCMError(
346-
"When distributing over multiple grids, the product of the number of"
347-
" points of the grids needs to match the number of available devices."
348-
f" Gridpoints: {permutations} Available Devices: {len(avail_devices)}"
365+
"When distributing over multiple grids, the product of the number of"
366+
" points of the grids needs to match the number of available devices."
367+
f" Gridpoints: {permutations} Available Devices: {len(avail_devices)}"
349368
)
350369
return self._base_state_action_space.replace(
351-
states=MappingProxyType(new_states),
352-
continuous_actions=MappingProxyType(new_continuous_actions)
353-
)
354-
355-
def _partitioning_algo(grids: list[Grid], devices: list):
356-
number_devices = len(devices)
357-
print(len(grids[0].to_jax()))
358-
first_groups = [[] for i in range(len(grids[0].to_jax()))]
359-
for i in range(grids[0].to_jax().shape[0]):
360-
for j in range(number_devices//len(grids[0].to_jax())):
361-
first_groups[i].append(devices[j+number_devices//grids[0].to_jax().shape[0]])
362-
device_orders = [sum(first_groups, [])]
363-
last_groups = []
364-
for grid in grids[1:]:
365-
n_points = grid.to_jax().shape[0]
366-
next_groups = [[] for i in range(n_points)]
367-
for group in last_groups:
368-
for i in range(n_points):
369-
for j in range(len(group)/n_points):
370-
next_groups[i].append(devices[j+number_devices/n_points])
371-
device_orders.append(sum(next_groups, []))
372-
last_groups = next_groups
373-
return device_orders
374-
375-
370+
states=MappingProxyType(new_states),
371+
continuous_actions=MappingProxyType(new_continuous_actions),
372+
)
376373

377374

378375
@dataclasses.dataclass(frozen=True)

src/lcm/solution/solve_brute.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def solve(
5656

5757
next_regime_to_V_arr = MappingProxyType(
5858
{
59-
regime_name: jax.device_put(jnp.zeros(shape), device=sharding)
60-
for regime_name, (shape, sharding) in regime_V_shapes.items()
59+
regime_name: jax.device_put(jnp.zeros(shape))
60+
for regime_name, shape in regime_V_shapes.items()
6161
}
6262
)
6363

@@ -71,7 +71,7 @@ def solve(
7171
max_compilation_workers=max_compilation_workers,
7272
logger=logger,
7373
)
74-
74+
7575
solution: dict[int, MappingProxyType[RegimeName, FloatND]] = {}
7676

7777
# Async diagnostics accumulators: every `jnp.any(isnan)`,
@@ -134,7 +134,6 @@ def solve(
134134
period=jnp.int32(period),
135135
age=ages.values[period],
136136
)
137-
138137
# Async reductions: gated on log level. `"off"` skips
139138
# everything — no kernel launches, no host syncs, no
140139
# NaN fail-fast. `"warning"` / `"progress"` launches the
@@ -325,9 +324,7 @@ def _compile_and_log(
325324
compiled[func_id] = comp
326325

327326
# Map back to (regime, period) keys.
328-
return {
329-
key: compiled[_func_dedup_key(func=func)] for key, func in all_functions.items()
330-
}
327+
return {key: func for key, func in all_functions.items()}
331328

332329

333330
def _resolve_compilation_workers(*, max_compilation_workers: int | None) -> int:
@@ -378,24 +375,32 @@ def _get_regime_V_shapes_and_shardings(
378375
Dict of regime names to V array shapes.
379376
380377
"""
381-
shapes_and_shardings: dict[RegimeName, tuple[tuple[int, ...], jax.NamedSharding]] = {}
382-
avail_devices = jax.devices()
378+
shapes_and_shardings: dict[
379+
RegimeName, tuple[tuple[int, ...], jax.NamedSharding]
380+
] = {}
381+
avail_devices = jax.devices()
383382
for regime_name, regime in internal_regimes.items():
384383
state_action_space = regime.state_action_space(
385384
regime_params=internal_params[regime_name],
386385
)
387386
spec = []
388387
for name in state_action_space.states:
389388
if regime.grids[name].distributed:
390-
spec.append('X')
389+
spec.append("X")
391390
else:
392391
spec.append(None)
393392
shape = tuple(len(v) for v in state_action_space.states.values())
394-
mesh = jax.make_mesh((len(avail_devices),), ('X'), axis_types=(jax.sharding.AxisType.Auto),devices=avail_devices)
395-
sharding = jax.NamedSharding(mesh, spec= jax.P(*spec))
396-
shapes_and_shardings[regime_name] = (shape, sharding)
393+
mesh = jax.make_mesh(
394+
(len(avail_devices),),
395+
("X"),
396+
axis_types=(jax.sharding.AxisType.Auto),
397+
devices=avail_devices,
398+
)
399+
400+
shapes_and_shardings[regime_name] = shape
397401
return shapes_and_shardings
398402

403+
399404
@dataclass(frozen=True)
400405
class _DiagnosticRow:
401406
"""Metadata captured during the backward-induction loop.

0 commit comments

Comments
 (0)