
Add Parallelization across multiple devices during Solve #346

Open
mj023 wants to merge 6 commits into main from distributed

Conversation

mj023 (Collaborator) commented May 8, 2026

This PR is a continuation of #147. It uses JAX's automatic parallelization inside jitted functions to make it possible to split the state space across multiple devices and then let every device solve its part independently.
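For context, a minimal sketch of the mechanism (names and shapes are illustrative, not this PR's API): an array placed on a sharded layout flows through a jitted function, and JAX's automatic parallelization keeps each shard's computation on its own device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative sketch only; not this PR's API.
devices = jax.devices()
mesh = Mesh(np.array(devices), ("states",))
sharding = NamedSharding(mesh, PartitionSpec("states"))

# Stand-in for a state grid; its length is a multiple of the device
# count, matching the requirement described below.
state_grid = jax.device_put(jnp.linspace(0.0, 1.0, 8 * len(devices)), sharding)

@jax.jit
def solve_step(states):
    # Elementwise work requires no communication, so each device
    # solves its shard of the state space independently.
    return jnp.tanh(states) ** 2

V_arr = solve_step(state_grid)
print(V_arr.sharding)  # the output inherits the input's sharding
```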

Distribution strategy

The grids get a new argument `distributed` that the user can use to specify which grids should be considered for the distribution across devices. If only one grid is marked for distribution, its length must be a multiple of the number of available devices; if multiple grids are marked, the product of their lengths must equal the number of available devices exactly (it might be possible to relax this requirement). The grids then need to be moved to the right devices after they have been initialized with runtime-supplied points and shocks. The resulting `VF_arr` is then automatically split across the devices on which its values were computed.
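A sketch of the two validation rules just described, written as a hypothetical standalone check (the name `validate_distributed_grids` is invented for illustration; the PR's actual check lives elsewhere, see the commit message below):

```python
import math
import jax

def validate_distributed_grids(grid_lengths: list[int]) -> None:
    """Hypothetical check implementing the two rules above."""
    n_devices = len(jax.devices())
    if len(grid_lengths) == 1:
        if grid_lengths[0] % n_devices != 0:
            raise ValueError(
                f"Distributed grid length {grid_lengths[0]} must be a "
                f"multiple of the {n_devices} available devices."
            )
    elif math.prod(grid_lengths) != n_devices:
        raise ValueError(
            f"Product of distributed grid lengths "
            f"({math.prod(grid_lengths)}) must equal the number of "
            f"available devices ({n_devices})."
        )
```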

TODO

  • Move grids to right devices after state action space creation
  • Fix AOT-Compilation
  • Test communication overhead (see the timing sketch after this list)
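As a starting point for the last item, a hedged sketch of how per-step overhead could be measured on sharded versus single-device inputs (the setup is illustrative, not part of this PR):

```python
import time
import jax

def mean_step_time(fn, x, n_runs=10):
    """Average wall-clock seconds per call of a jitted function.

    block_until_ready matters because JAX dispatches asynchronously;
    without it only the dispatch would be timed, not the computation
    (including any cross-device communication).
    """
    fn(x).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_runs):
        out = fn(x)
    out.block_until_ready()
    return (time.perf_counter() - start) / n_runs
```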

mj023 mentioned this pull request May 8, 2026
read-the-docs-community (Bot) commented May 8, 2026

github-actions (Bot) commented May 8, 2026

Benchmark comparison (main → HEAD)

Comparing 99a5e31d (main) → 7eeaf370 (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 27.474 s | 26.771 s | 0.97 | |
| | peak GPU mem | 509 MB | 847 MB | 1.66 | |
| | compilation time | 299.51 s | 301.27 s | 1.01 | |
| | peak CPU mem | 7.65 GB | 7.53 GB | 0.98 | |
| Mahler-Yum | execution time | 4.712 s | 4.742 s | 1.01 | |
| | peak GPU mem | 529 MB | 529 MB | 1.00 | |
| | compilation time | 14.59 s | 16.96 s | 1.16 | |
| | peak CPU mem | 1.68 GB | 1.73 GB | 1.03 | |
| Precautionary Savings - Solve | execution time | 50.8 ms | 49.9 ms | 0.98 | |
| | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| | compilation time | 2.71 s | 2.80 s | 1.03 | |
| | peak CPU mem | 1.13 GB | 1.13 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 126.7 ms | 127.0 ms | 1.00 | |
| | peak GPU mem | 344 MB | 344 MB | 1.00 | |
| | compilation time | 4.90 s | 7.11 s | 1.45 | |
| | peak CPU mem | 1.31 GB | 1.32 GB | 1.01 | |
| Precautionary Savings - Solve & Simulate | execution time | 145.2 ms | 152.9 ms | 1.05 | |
| | peak GPU mem | 578 MB | 578 MB | 1.00 | |
| | compilation time | 7.02 s | 9.00 s | 1.28 | |
| | peak CPU mem | 1.28 GB | 1.31 GB | 1.03 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 283.3 ms | 295.9 ms | 1.04 | |
| | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| | compilation time | 7.58 s | 9.86 s | 1.30 | |
| | peak CPU mem | 1.34 GB | 1.36 GB | 1.02 | |

hmgaudecker pushed a commit that referenced this pull request May 9, 2026
Squash of `distributed` (mj023). Adds a `distributed=True` flag on
`DiscreteGrid` to shard the grid across JAX devices, threads the
distribution pattern through `solve_brute._get_regime_V_shapes_and_shardings`,
and validates the device-count match at runtime via a new check in
`InternalRegime.state_action_space`.

Rebased on top of `feat/canonical-float-dtype` so the work picks up the
dtype-barrier and simulate-AOT changes. Also retargets the second caller
of the renamed shapes helper (`_reconstruct_next_regime_to_V_arr`) at
the new name.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
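To make the helper's role concrete, a hedged illustration of what a shapes-and-shardings computation could look like (this is not the code of `_get_regime_V_shapes_and_shardings`, only a sketch of the idea; `V_shape_and_sharding` and its parameters are invented for illustration):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def V_shape_and_sharding(grid_lengths, distributed_axis):
    """Shape of the value-function array plus a sharding that splits it
    along the distributed grid's axis and replicates the other axes."""
    shape = tuple(grid_lengths)
    mesh = Mesh(np.array(jax.devices()), ("devices",))
    spec = [None] * len(shape)
    spec[distributed_axis] = "devices"
    return shape, NamedSharding(mesh, PartitionSpec(*spec))
```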
hmgaudecker changed the base branch from main to feat/canonical-float-dtype May 9, 2026 11:42
Base automatically changed from feat/canonical-float-dtype to main May 11, 2026 07:14
Add a `distributed=True` flag on `DiscreteGrid` to shard the grid
across JAX devices, thread the distribution pattern through
`solve_brute._get_regime_V_shapes_and_shardings`, and validate the
device-count match at runtime via a new check in
`InternalRegime.state_action_space`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
