@@ -49,13 +49,14 @@ def solve(
4949 # Compute V array shapes and build a consistent next_regime_to_V_arr
5050 # template. Using the same pytree structure (keys and shapes) across
5151 # all periods avoids JIT re-compilation from pytree mismatches.
52- regime_V_shapes = _get_regime_V_shapes (
52+ regime_V_shapes = _get_regime_V_shapes_and_shardings (
5353 internal_regimes = internal_regimes ,
5454 internal_params = internal_params ,
5555 )
56+
5657 next_regime_to_V_arr = MappingProxyType (
5758 {
58- regime_name : jnp .zeros (shape )
59+ regime_name : jax . device_put ( jnp .zeros (shape ) )
5960 for regime_name , shape in regime_V_shapes .items ()
6061 }
6162 )
@@ -146,7 +147,6 @@ def solve(
146147 period = jnp .int32 (period ),
147148 age = ages .values [period ],
148149 )
149-
150150 # Async reductions: gated on log level. `"off"` skips
151151 # everything — no kernel launches, no host syncs, no
152152 # NaN fail-fast. `"warning"` / `"progress"` folds two
@@ -351,9 +351,7 @@ def _compile_and_log(
351351 compiled [func_id ] = comp
352352
353353 # Map back to (regime, period) keys.
354- return {
355- key : compiled [_func_dedup_key (func = func )] for key , func in all_functions .items ()
356- }
354+ return {key : func for key , func in all_functions .items ()}
357355
358356
359357def _resolve_compilation_workers (* , max_compilation_workers : int | None ) -> int :
@@ -386,7 +384,7 @@ def _func_dedup_key(*, func: Callable) -> Hashable:
386384 return id (func )
387385
388386
389- def _get_regime_V_shapes (
387+ def _get_regime_V_shapes_and_shardings (
390388 * ,
391389 internal_regimes : MappingProxyType [RegimeName , InternalRegime ],
392390 internal_params : InternalParams ,
@@ -404,13 +402,30 @@ def _get_regime_V_shapes(
404402 Dict of regime names to V array shapes.
405403
406404 """
407- shapes : dict [RegimeName , tuple [int , ...]] = {}
405+ shapes_and_shardings : dict [
406+ RegimeName , tuple [tuple [int , ...], jax .NamedSharding ]
407+ ] = {}
408+ avail_devices = jax .devices ()
408409 for regime_name , regime in internal_regimes .items ():
409410 state_action_space = regime .state_action_space (
410411 regime_params = internal_params [regime_name ],
411412 )
412- shapes [regime_name ] = tuple (len (v ) for v in state_action_space .states .values ())
413- return shapes
413+ spec = []
414+ for name in state_action_space .states :
415+ if regime .grids [name ].distributed :
416+ spec .append ("X" )
417+ else :
418+ spec .append (None )
419+ shape = tuple (len (v ) for v in state_action_space .states .values ())
420+ mesh = jax .make_mesh (
421+ (len (avail_devices ),),
422+ ("X" ),
423+ axis_types = (jax .sharding .AxisType .Auto ),
424+ devices = avail_devices ,
425+ )
426+
427+ shapes_and_shardings [regime_name ] = shape
428+ return shapes_and_shardings
414429
415430
416431@dataclass (frozen = True )
@@ -559,9 +574,9 @@ def _reconstruct_next_regime_to_V_arr(
559574
560575 We rebuild the same mapping post-hoc from `solution`. The shapes come from
561576 the regime's state-action space at the supplied params — identical to what
562- `_get_regime_V_shapes ` saw during solve setup.
577+ `_get_regime_V_shapes_and_shardings ` saw during solve setup.
563578 """
564- regime_V_shapes = _get_regime_V_shapes (
579+ regime_V_shapes = _get_regime_V_shapes_and_shardings (
565580 internal_regimes = internal_regimes ,
566581 internal_params = internal_params ,
567582 )
0 commit comments