Add Parallelization across multiple devices during Solve #346
Open
hmgaudecker pushed a commit that referenced this pull request on May 9, 2026
Squash of `distributed` (mj023). Adds a `distributed=True` flag on `DiscreteGrid` to shard the grid across JAX devices, threads the distribution pattern through `solve_brute._get_regime_V_shapes_and_shardings`, and validates the device-count match at runtime via a new check in `InternalRegime.state_action_space`. Rebased on top of `feat/canonical-float-dtype` so the work picks up the dtype-barrier and simulate-AOT changes. Also retargets the second caller of the renamed shapes helper (`_reconstruct_next_regime_to_V_arr`) at the new name. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a `distributed=True` flag on `DiscreteGrid` to shard the grid across JAX devices, thread the distribution pattern through `solve_brute._get_regime_V_shapes_and_shardings`, and validate the device-count match at runtime via a new check in `InternalRegime.state_action_space`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This PR is a continuation of #147. It uses JAX automatic parallelization inside jitted functions to split the state space across multiple devices and let every device solve its part independently.
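The mechanism described above can be sketched in plain JAX, independent of this PR's code. The sketch is only illustrative: the mesh axis name `"grid"`, the array `wealth_grid`, and the function `toy_value` are hypothetical names, not identifiers from this PR, and the grid length is chosen as a multiple of the device count so the axis splits evenly.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all available devices (works with a single device too).
devices = np.array(jax.devices())
n_devices = devices.size
mesh = Mesh(devices, axis_names=("grid",))  # "grid" is an illustrative axis name

# Hypothetical state grid; its length is a multiple of the device count
# so that the axis splits evenly across devices.
n_points = 4 * n_devices
wealth_grid = jnp.linspace(1.0, 10.0, n_points)

# Move the grid to the devices, sharded along its only axis.
sharding = NamedSharding(mesh, PartitionSpec("grid"))
wealth_sharded = jax.device_put(wealth_grid, sharding)


@jax.jit
def toy_value(wealth):
    # Stand-in for a per-state computation. Inside jit, the compiler
    # partitions the work according to the sharding of the input.
    return jnp.log(wealth)


VF_arr_toy = toy_value(wealth_sharded)

# Inspect how the result is laid out across devices.
print(VF_arr_toy.sharding)
print(len(VF_arr_toy.addressable_shards), "shard(s) for", n_devices, "device(s)")
```

For an elementwise computation like this one, the output typically keeps the sharding of the input, so each device holds the values it computed. To try this on a machine without multiple accelerators, setting `XLA_FLAGS=--xla_force_host_platform_device_count=4` before importing JAX makes the CPU backend expose several devices.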
Distribution strategy
The grids get a new argument, `distributed`, that lets the user specify which grids should be distributed across devices. If only one grid is marked for distribution, its length must be a multiple of the number of available devices. If multiple grids are marked, the product of their lengths must be exactly the number of available devices (it might be possible to relax this requirement). The grids then need to be moved to the right devices after they have been initialized with runtime-supplied points and shocks. The resulting `VF_arr` is then automatically split across the devices, with each value residing on the device where it was calculated.
TODO