Skip to content

Commit 585f1d4

Browse files
Update
1 parent 9fe3104 commit 585f1d4

1 file changed

Lines changed: 2 additions & 10 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -437,40 +437,32 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
437437
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
438438
graphdef, params, state = nnx.split(
439439
layers, nnx.Param, ...
440-
) # state: the mutable state we carry (KV cache, RNGs, etc.)
440+
)
441441

442442
scan_axis = self.config.param_scan_axis
443443
if scan_axis != 0:
444-
# Move scan_axis to 0 so scan can iterate over it
445444
params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
446445

447446
layer_cls = layers.__class__
448447
sig = inspect.signature(layer_cls.__call__)
449448
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
450449

451-
layer_cls = layers.__class__ # Access the underlying class
450+
layer_cls = layers.__class__
452451
sig = inspect.signature(layer_cls.__call__)
453-
# Filter kwargs to only include keys that exist in the layer's signature
454452
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
455453

456454
def layer_fn(carry, scanned_vars):
457-
# Unpack the sliced variables for THIS layer
458455
current_params, current_state = scanned_vars
459456

460457
if self.config.parameter_memory_host_offload:
461458
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
462459

463-
# Merge using the SLICED state
464460
layer = nnx.merge(graphdef, current_params, current_state)
465461

466-
# Run the layer (Filter kwargs if using the solution from previous turn)
467462
layer_out = layer(carry, *args, **valid_kwargs)
468463

469464
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
470465

471-
# Extract the updated state to return it
472-
# _, new_current_state = nnx.split(layer, nnx.Param, ...)
473-
new_current_state = nnx.state(layer)
474466
return new_carry, new_current_state
475467

476468
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)

0 commit comments

Comments
 (0)