@@ -463,19 +463,51 @@ def _update_leaf(leaf):
463463
464464 def _apply_layer_with_remat (self , layer : nnx .Module , y : jax .Array , policy : Any , prevent_cse : bool , ** kwargs ):
465465 """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block."""
466+ """Runs the layer stack using nnx.scan."""
467+ policy = self .get_remat_policy ()
468+ prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
469+ graphdef , params , state = nnx .split (layers , nnx .Param , ...)
466470
467- graphdef , state = nnx .split (layer )
471+ scan_axis = self .config .param_scan_axis
472+ if scan_axis != 0 :
473+ params = jax .tree .map (lambda x : jnp .moveaxis (x , scan_axis , 0 ), params )
474+
475+ layer_cls = layers .__class__
476+ sig = inspect .signature (layer_cls .__call__ )
477+ valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
468478
469- def pure_layer_fn (state_in , y_in ):
470- merged_layer = nnx .merge (graphdef , state_in )
471- out = merged_layer (y_in , ** kwargs )
472- return out , nnx .state (merged_layer )
479+ def _extract_matching_state (template , full ):
480+ if isinstance (template , nnx .State ):
481+ return nnx .State ({k : _extract_matching_state (v , full [k ]) for k , v in template .items ()})
482+ elif isinstance (template , dict ):
483+ return {k : _extract_matching_state (v , full [k ]) for k , v in template .items ()}
484+ return full
473485
474- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
475- out , new_state = checkpointed_fn (state , y )
476- nnx .update (layer , new_state )
486+ def layer_fn (carry , scanned_vars ):
487+ current_params , current_state = scanned_vars
477488
478- return out
489+ if self .config .parameter_memory_host_offload :
490+ current_params = jax .tree .map (lambda x : jax .device_put (x , max_utils .device_space ()), current_params )
491+
492+ layer = nnx .merge (graphdef , current_params , current_state )
493+ layer_out = layer (carry , * args , ** valid_kwargs )
494+ new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
495+
496+ new_full_state = nnx .state (layer )
497+ new_current_state = _extract_matching_state (current_state , new_full_state )
498+
499+ # ONLY return non-param state to prevent memory duplication of weights
500+ return new_carry , new_current_state
501+
502+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
503+
504+ final_carry , scanned_other = jax .lax .scan (layer_fn , x_in , (params , state ))
505+
506+ if scan_axis != 0 :
507+ params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), params )
508+
509+ scanned_state = nnx .State .merge (params , scanned_other )
510+ return final_carry , nnx .merge (graphdef , scanned_state )
479511
480512 def _apply_layers_sequentially (self , layers , x_in , * args , length : int , ** kwargs ):
481513 """Runs the layer stack using nnx.scan."""
@@ -885,8 +917,6 @@ def _extract_slice(x, idx, axis):
885917
886918 # Run the single layer
887919 out = single_layer (
888- y , * args , decoder_input_tokens = kwargs .get ("decoder_input_tokens" ), ** kwargs .get ("layer_kwargs" , {})
889- )
890920 y = out [0 ] if isinstance (out , tuple ) else out
891921
892922 # Re-merge the updated state back into the specific slice of the stack
@@ -944,6 +974,7 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args
944974
945975 return y
946976
977+
947978 def _apply_interleaved_scanned_layers (self , y , layer_stack , start_idx , end_idx , engram_indices , * args , ** kwargs ):
948979 """Applies a mix of scanned standard layers and unscanned Engram layers."""
949980 current_idx = start_idx
0 commit comments