@@ -394,33 +394,61 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
394394 return nnx_wrappers .ToNNX (layer_linen , rngs = rngs )
395395
396396 def _create_scanned_layers (self , decoder_layer_class , length : int , metadata_axis_name : str , rngs : nnx .Rngs , ** layer_kwargs ):
397- """Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
397+ """Creates a scanned stack of `length` layers, constructing them sequentially inside jax.lax.scan.
398398
399- def create_layer_fn (rng ):
399+ Uses jax.lax.scan instead of nnx.vmap; the stated goal is lower peak memory during initialization.
400+ Per the change author: with vmap, all layers' parameters are created simultaneously (O(N) peak memory);
401+ with scan, parameters are created one layer at a time (O(1) peak intermediate memory),
402+ which is meant to prevent OOM on memory-constrained devices like TPU v6e-4. Returns the merged stacked module.
403+ """
404+ scan_axis = self .config .param_scan_axis
405+
406+ # Split rngs `length` ways, then separate graphdef from state so the state can be scanned over and re-merged per layer.
407+ split_rngs = rngs .split (length )
408+ rngs_graphdef , rngs_state = nnx .split (split_rngs )
409+
410+ # Create a reference layer to capture the module graph structure (graphdef).
411+ # This layer's params are discarded — only the structure is kept. NOTE(review): it is constructed outside the scan with a fixed nnx.Rngs(0); presumably this materializes one full layer — confirm.
412+ ref_rngs = nnx .Rngs (0 )
413+ ref_layer = decoder_layer_class (
414+ config = self .config , mesh = self .mesh , quant = self .quant ,
415+ model_mode = self .model_mode , rngs = ref_rngs , ** layer_kwargs
416+ )
417+ layer_graphdef , _ , _ = nnx .split (ref_layer , nnx .Param , ...)
418+
419+ # Sequentially create each layer's parameters via jax.lax.scan; the carry is unused (None) and the outputs are the per-layer (params, rest).
420+ # The scan body is traced once; XLA executes it N times with different RNG keys,
421+ # keeping only one layer's intermediate state alive at a time (author's expectation — not verifiable from this code).
422+ def scan_body (carry , rng_state_slice ):
423+ layer_rngs = nnx .merge (rngs_graphdef , rng_state_slice )
400424 layer = decoder_layer_class (
401- config = self .config , mesh = self .mesh , quant = self .quant , model_mode = self .model_mode , rngs = rng , ** layer_kwargs
425+ config = self .config , mesh = self .mesh , quant = self .quant ,
426+ model_mode = self .model_mode , rngs = layer_rngs , ** layer_kwargs
402427 )
403- return nnx .split (layer , nnx .Param , ...)
404- # return layer
405-
406- try :
407- forked_rngs = rngs .fork (split = length )
408- except : # pylint: disable=bare-except
409- pass
410-
411- graphdef , params , rest = nnx .vmap (
412- create_layer_fn ,
413- in_axes = 0 ,
414- out_axes = (None , self .config .param_scan_axis , 0 ),
415- axis_name = metadata_axis_name ,
416- transform_metadata = {
417- nnx .PARTITION_NAME : metadata_axis_name ,
418- "param_scan_axis" : self .config .param_scan_axis ,
419- },
420- )(forked_rngs )
421- layers_vmapped = nnx .merge (graphdef , params , rest )
428+ _ , params , rest = nnx .split (layer , nnx .Param , ...)
429+ return carry , (params , rest )
422430
423- return layers_vmapped
431+ _ , (stacked_params , stacked_rest ) = jax .lax .scan (scan_body , None , rngs_state )
432+
433+ # jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis; non-Param state (stacked_rest) is left stacked on axis 0.
434+ if scan_axis != 0 :
435+ stacked_params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), stacked_params )
436+
437+ # Add partition metadata that nnx.vmap's transform_metadata would normally set.
438+ # This metadata is reportedly read by variable_to_logically_partitioned() in initializers.py
439+ # to insert the scan axis name into logical sharding specs. NOTE(review): only these two keys are added; existing metadata is copied unchanged — confirm that matches what nnx.vmap produced.
440+ def _add_partition_metadata (state ):
441+ def _update (vs ):
442+ if isinstance (vs , nnx .Variable ):
443+ metadata = vs .get_metadata ()
444+ return type (vs )(vs .get_value (), ** {** metadata , nnx .PARTITION_NAME : metadata_axis_name , "param_scan_axis" : scan_axis })
445+ return vs
446+ return jax .tree .map (_update , state , is_leaf = lambda x : isinstance (x , nnx .Variable ))
447+
448+ stacked_params = _add_partition_metadata (stacked_params )
449+ stacked_rest = _add_partition_metadata (stacked_rest )
450+
451+ return nnx .merge (layer_graphdef , stacked_params , stacked_rest )
424452
425453 def _apply_layer_with_remat (self , layer : nnx .Module , y : jax .Array , policy : Any , prevent_cse : bool , ** kwargs ):
426454 """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block."""
0 commit comments