Skip to content

Commit 74ab249

Browse files
Update
1 parent dc403a0 commit 74ab249

1 file changed

Lines changed: 42 additions & 11 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -463,19 +463,51 @@ def _update_leaf(leaf):
463463

464464
def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs):
465465
"""Helper to cleanly apply jax.checkpoint to a single unscanned layer or block."""
466+
"""Runs the layer stack using nnx.scan."""
467+
policy = self.get_remat_policy()
468+
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
469+
graphdef, params, state = nnx.split(layers, nnx.Param, ...)
466470

467-
graphdef, state = nnx.split(layer)
471+
scan_axis = self.config.param_scan_axis
472+
if scan_axis != 0:
473+
params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
474+
475+
layer_cls = layers.__class__
476+
sig = inspect.signature(layer_cls.__call__)
477+
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
468478

469-
def pure_layer_fn(state_in, y_in):
470-
merged_layer = nnx.merge(graphdef, state_in)
471-
out = merged_layer(y_in, **kwargs)
472-
return out, nnx.state(merged_layer)
479+
def _extract_matching_state(template, full):
480+
if isinstance(template, nnx.State):
481+
return nnx.State({k: _extract_matching_state(v, full[k]) for k, v in template.items()})
482+
elif isinstance(template, dict):
483+
return {k: _extract_matching_state(v, full[k]) for k, v in template.items()}
484+
return full
473485

474-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
475-
out, new_state = checkpointed_fn(state, y)
476-
nnx.update(layer, new_state)
486+
def layer_fn(carry, scanned_vars):
487+
current_params, current_state = scanned_vars
477488

478-
return out
489+
if self.config.parameter_memory_host_offload:
490+
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
491+
492+
layer = nnx.merge(graphdef, current_params, current_state)
493+
layer_out = layer(carry, *args, **valid_kwargs)
494+
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
495+
496+
new_full_state = nnx.state(layer)
497+
new_current_state = _extract_matching_state(current_state, new_full_state)
498+
499+
# ONLY return non-param state to prevent memory duplication of weights
500+
return new_carry, new_current_state
501+
502+
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
503+
504+
final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
505+
506+
if scan_axis != 0:
507+
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)
508+
509+
scanned_state = nnx.State.merge(params, scanned_other)
510+
return final_carry, nnx.merge(graphdef, scanned_state)
479511

480512
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
481513
"""Runs the layer stack using nnx.scan."""
@@ -885,8 +917,6 @@ def _extract_slice(x, idx, axis):
885917

886918
# Run the single layer
887919
out = single_layer(
888-
y, *args, decoder_input_tokens=kwargs.get("decoder_input_tokens"), **kwargs.get("layer_kwargs", {})
889-
)
890920
y = out[0] if isinstance(out, tuple) else out
891921

892922
# Re-merge the updated state back into the specific slice of the stack
@@ -944,6 +974,7 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args
944974

945975
return y
946976

977+
947978
def _apply_interleaved_scanned_layers(self, y, layer_stack, start_idx, end_idx, engram_indices, *args, **kwargs):
948979
"""Applies a mix of scanned standard layers and unscanned Engram layers."""
949980
current_idx = start_idx

0 commit comments

Comments
 (0)