@@ -437,40 +437,32 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
437437 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
438438 graphdef , params , state = nnx .split (
439439 layers , nnx .Param , ...
440- ) # state: the mutable state we carry (KV cache, RNGs, etc.)
440+ )
441441
442442 scan_axis = self .config .param_scan_axis
443443 if scan_axis != 0 :
444- # Move scan_axis to 0 so scan can iterate over it
445444 params = jax .tree .map (lambda x : jnp .moveaxis (x , scan_axis , 0 ), params )
446445
447446 layer_cls = layers .__class__
448447 sig = inspect .signature (layer_cls .__call__ )
449448 valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
450449
451- layer_cls = layers .__class__ # Access the underlying class
450+ layer_cls = layers .__class__
452451 sig = inspect .signature (layer_cls .__call__ )
453- # Filter kwargs to only include keys that exist in the layer's signature
454452 valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
455453
456454 def layer_fn (carry , scanned_vars ):
457- # Unpack the sliced variables for THIS layer
458455 current_params , current_state = scanned_vars
459456
460457 if self .config .parameter_memory_host_offload :
461458 current_params = jax .tree .map (lambda x : jax .device_put (x , max_utils .device_space ()), current_params )
462459
463- # Merge using the SLICED state
464460 layer = nnx .merge (graphdef , current_params , current_state )
465461
466- # Run the layer (Filter kwargs if using the solution from previous turn)
467462 layer_out = layer (carry , * args , ** valid_kwargs )
468463
469464 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
470465
471- # Extract the updated state to return it
472- # _, new_current_state = nnx.split(layer, nnx.Param, ...)
473- new_current_state = nnx .state (layer )
474466 return new_carry , new_current_state
475467
476468 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
0 commit comments