Skip to content

Commit 86a5501

Browse files
committed
test mesa
1 parent b5432f4 commit 86a5501

1 file changed

Lines changed: 51 additions & 23 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -394,33 +394,61 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
394394
return nnx_wrappers.ToNNX(layer_linen, rngs=rngs)
395395

396396
def _create_scanned_layers(self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs):
397-
"""Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
397+
"""Creates a scanned stack of layers using jax.lax.scan for memory-efficient sequential initialization.
398398
399-
def create_layer_fn(rng):
399+
Uses jax.lax.scan instead of nnx.vmap to reduce peak memory during initialization.
400+
With vmap, all layers' parameters are created simultaneously (O(N) peak memory).
401+
With scan, parameters are created one layer at a time (O(1) peak intermediate memory),
402+
which prevents OOM on memory-constrained devices like TPU v6e-4.
403+
"""
404+
scan_axis = self.config.param_scan_axis
405+
406+
# Split rngs to get per-layer RNG states
407+
split_rngs = rngs.split(length)
408+
rngs_graphdef, rngs_state = nnx.split(split_rngs)
409+
410+
# Create a reference layer to capture the module graph structure (graphdef).
411+
# This layer's params are discarded — only the structure is kept.
412+
ref_rngs = nnx.Rngs(0)
413+
ref_layer = decoder_layer_class(
414+
config=self.config, mesh=self.mesh, quant=self.quant,
415+
model_mode=self.model_mode, rngs=ref_rngs, **layer_kwargs
416+
)
417+
layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...)
418+
419+
# Sequentially create each layer's parameters via jax.lax.scan.
420+
# The scan body is traced once; XLA executes it N times with different RNG keys,
421+
# keeping only one layer's intermediate state alive at a time.
422+
def scan_body(carry, rng_state_slice):
423+
layer_rngs = nnx.merge(rngs_graphdef, rng_state_slice)
400424
layer = decoder_layer_class(
401-
config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs
425+
config=self.config, mesh=self.mesh, quant=self.quant,
426+
model_mode=self.model_mode, rngs=layer_rngs, **layer_kwargs
402427
)
403-
return nnx.split(layer, nnx.Param, ...)
404-
# return layer
405-
406-
try:
407-
forked_rngs = rngs.fork(split=length)
408-
except: # pylint: disable=bare-except
409-
pass
410-
411-
graphdef, params, rest = nnx.vmap(
412-
create_layer_fn,
413-
in_axes=0,
414-
out_axes=(None, self.config.param_scan_axis, 0),
415-
axis_name=metadata_axis_name,
416-
transform_metadata={
417-
nnx.PARTITION_NAME: metadata_axis_name,
418-
"param_scan_axis": self.config.param_scan_axis,
419-
},
420-
)(forked_rngs)
421-
layers_vmapped = nnx.merge(graphdef, params, rest)
428+
_, params, rest = nnx.split(layer, nnx.Param, ...)
429+
return carry, (params, rest)
422430

423-
return layers_vmapped
431+
_, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state)
432+
433+
# jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis.
434+
if scan_axis != 0:
435+
stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params)
436+
437+
# Add partition metadata that nnx.vmap's transform_metadata would normally set.
438+
# This metadata is read by variable_to_logically_partitioned() in initializers.py
439+
# to insert the scan axis name into logical sharding specs.
440+
def _add_partition_metadata(state):
441+
def _update(vs):
442+
if isinstance(vs, nnx.Variable):
443+
metadata = vs.get_metadata()
444+
return type(vs)(vs.get_value(), **{**metadata, nnx.PARTITION_NAME: metadata_axis_name, "param_scan_axis": scan_axis})
445+
return vs
446+
return jax.tree.map(_update, state, is_leaf=lambda x: isinstance(x, nnx.Variable))
447+
448+
stacked_params = _add_partition_metadata(stacked_params)
449+
stacked_rest = _add_partition_metadata(stacked_rest)
450+
451+
return nnx.merge(layer_graphdef, stacked_params, stacked_rest)
424452

425453
def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs):
426454
"""Helper to cleanly apply jax.checkpoint to a single unscanned layer or block."""

0 commit comments

Comments
 (0)