|
1 | 1 | import dataclasses |
2 | 2 | from collections.abc import Callable |
| 3 | +from functools import reduce |
| 4 | +from operator import mul |
3 | 5 | from types import MappingProxyType |
4 | 6 | from typing import cast |
5 | 7 |
|
6 | | -from functools import reduce |
7 | | -from operator import mul |
| 8 | +import jax |
8 | 9 | import pandas as pd |
9 | 10 | from jax import Array |
10 | | -import jax |
| 11 | + |
11 | 12 | from lcm.exceptions import PyLCMError |
12 | 13 | from lcm.grids import Grid, IrregSpacedGrid |
13 | 14 | from lcm.shocks import _ShockGrid |
@@ -314,65 +315,61 @@ def state_action_space(self, regime_params: FlatRegimeParams) -> StateActionSpac |
314 | 315 | | action_replacements |
315 | 316 | ) |
316 | 317 | if action_replacements |
317 | | - else None |
| 318 | + else dict(self._base_state_action_space.continuous_actions) |
318 | 319 | ) |
319 | | - |
| 320 | + |
320 | 321 | avail_devices = jax.devices() |
321 | | - distributed_grids = {name:grid for name,grid in self.grids.items() if grid.distributed == True} |
322 | | - print(distributed_grids) |
| 322 | + distributed_grids = { |
| 323 | + name: grid for name, grid in self.grids.items() if grid.distributed == True |
| 324 | + } |
323 | 325 | if len(distributed_grids) == 1: |
324 | 326 | n_points = distributed_grids[list(distributed_grids)[0]].to_jax().shape[0] |
325 | 327 | state_name = list(distributed_grids)[0] |
326 | 328 | if n_points % len(avail_devices) == 0: |
327 | | - mesh = jax.make_mesh((len(avail_devices),), ('X'), axis_types=(jax.sharding.AxisType.Auto),devices=avail_devices) |
328 | | - new_states[state_name] = jax.device_put(new_states[state_name], jax.NamedSharding(mesh=mesh, spec=jax.P('X',))) |
| 329 | + mesh = jax.make_mesh( |
| 330 | + (len(avail_devices),), |
| 331 | + ("X"), |
| 332 | + axis_types=(jax.sharding.AxisType.Auto), |
| 333 | + devices=avail_devices, |
| 334 | + ) |
| 335 | + new_states[state_name] = jax.device_put( |
| 336 | + new_states[state_name], |
| 337 | + jax.NamedSharding(mesh=mesh, spec=jax.P("X")), |
| 338 | + ) |
329 | 339 | else: |
330 | 340 | raise PyLCMError( |
331 | | - "When distributing over one grid, the number of points in the grid " |
332 | | - "needs to be a multiple of the available devices. Gridpoints: " |
333 | | - f" {n_points} Available Devices: {len(avail_devices)}" |
334 | | - ) |
| 341 | + "When distributing over one grid, the number of points in the grid " |
| 342 | + "needs to be a multiple of the available devices. Gridpoints: " |
| 343 | + f" {n_points} Available Devices: {len(avail_devices)}" |
| 344 | + ) |
335 | 345 | if len(distributed_grids) > 1: |
336 | | - permutations = reduce(mul, [grid.to_jax().shape[0] for grid in distributed_grids.values()]) |
337 | | - print(permutations) |
| 346 | + permutations = reduce( |
| 347 | + mul, [grid.to_jax().shape[0] for grid in distributed_grids.values()] |
| 348 | + ) |
338 | 349 | if permutations == len(avail_devices): |
339 | | - device_orders = _partitioning_algo(list(distributed_grids.values()), avail_devices) |
340 | | - print(device_orders) |
341 | | - for i, (state_name, grid) in enumerate(distributed_grids.items()): |
342 | | - mesh = jax.make_mesh((grid.to_jax().shape[0],), ('X'), devices=device_orders[i]) |
343 | | - new_states[state_name] = jax.device_put(new_states[state_name],jax.NamedSharding(mesh=mesh, spec=jax.P('X',))) |
| 350 | + mesh = jax.make_mesh( |
| 351 | + tuple(len(grid.to_jax()) for grid in distributed_grids.values()), |
| 352 | + tuple(distributed_grids.keys()), |
| 353 | + axis_types=tuple( |
| 354 | + jax.sharding.AxisType.Auto for grid in distributed_grids |
| 355 | + ), |
| 356 | + devices=avail_devices, |
| 357 | + ) |
| 358 | + for state_name in distributed_grids: |
| 359 | + new_states[state_name] = jax.device_put( |
| 360 | + new_states[state_name], |
| 361 | + jax.NamedSharding(mesh=mesh, spec=jax.P(state_name)), |
| 362 | + ) |
344 | 363 | else: |
345 | 364 | raise PyLCMError( |
346 | | - "When distributing over multiple grids, the product of the number of" |
347 | | - " points of the grids needs to match the number of available devices." |
348 | | - f" Gridpoints: {permutations} Available Devices: {len(avail_devices)}" |
| 365 | + "When distributing over multiple grids, the product of the number of" |
| 366 | + " points of the grids needs to match the number of available devices." |
| 367 | + f" Gridpoints: {permutations} Available Devices: {len(avail_devices)}" |
349 | 368 | ) |
350 | 369 | return self._base_state_action_space.replace( |
351 | | - states=MappingProxyType(new_states), |
352 | | - continuous_actions=MappingProxyType(new_continuous_actions) |
353 | | - ) |
354 | | - |
355 | | -def _partitioning_algo(grids: list[Grid], devices: list): |
356 | | - number_devices = len(devices) |
357 | | - print(len(grids[0].to_jax())) |
358 | | - first_groups = [[] for i in range(len(grids[0].to_jax()))] |
359 | | - for i in range(grids[0].to_jax().shape[0]): |
360 | | - for j in range(number_devices//len(grids[0].to_jax())): |
361 | | - first_groups[i].append(devices[j+number_devices//grids[0].to_jax().shape[0]]) |
362 | | - device_orders = [sum(first_groups, [])] |
363 | | - last_groups = [] |
364 | | - for grid in grids[1:]: |
365 | | - n_points = grid.to_jax().shape[0] |
366 | | - next_groups = [[] for i in range(n_points)] |
367 | | - for group in last_groups: |
368 | | - for i in range(n_points): |
369 | | - for j in range(len(group)/n_points): |
370 | | - next_groups[i].append(devices[j+number_devices/n_points]) |
371 | | - device_orders.append(sum(next_groups, [])) |
372 | | - last_groups = next_groups |
373 | | - return device_orders |
374 | | - |
375 | | - |
| 370 | + states=MappingProxyType(new_states), |
| 371 | + continuous_actions=MappingProxyType(new_continuous_actions), |
| 372 | + ) |
376 | 373 |
|
377 | 374 |
|
378 | 375 | @dataclasses.dataclass(frozen=True) |
|
0 commit comments