Skip to content

Commit 7c42a30

Browse files
author
Charles Li
committed
ad out_sharding in metadata
1 parent 8e17656 commit 7c42a30

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,13 +451,21 @@ def scan_body(carry, rng_state_slice):
451451

452452
# Add partition metadata that nnx.vmap's transform_metadata would normally set.
453453
# This metadata is read by variable_to_logically_partitioned() in initializers.py
454-
# to insert the scan axis name into logical sharding specs.
454+
# and by nnx.get_partition_spec() (via the updated out_sharding) to produce
455+
# correct sharding specs that include the scan axis dimension.
455456
def _add_scan_metadata(state, axis):
456457
def _update_leaf(leaf):
457458
if isinstance(leaf, nnx.VariableState):
458459
metadata = leaf.get_metadata()
459460
metadata[nnx.PARTITION_NAME] = metadata_axis_name
460461
metadata["param_scan_axis"] = axis
462+
# Insert the scan axis name into out_sharding so that
463+
# nnx.get_partition_spec returns specs matching the actual tensor rank.
464+
# Without this, scanned params are 3D but specs remain 2D.
465+
if "out_sharding" in metadata and metadata["out_sharding"]:
466+
sharding = list(metadata["out_sharding"])
467+
sharding.insert(axis, metadata_axis_name)
468+
metadata["out_sharding"] = tuple(sharding)
461469
return leaf.replace(**metadata)
462470
return leaf
463471

@@ -529,7 +537,13 @@ def _layer_fn(carry, scanned_vars):
529537
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)
530538

531539
scanned_state = nnx.State.merge(params, scanned_other)
532-
return final_carry, nnx.merge(graphdef, scanned_state)
540+
# Update the existing module in-place rather than creating a new one.
541+
# Creating a new module via nnx.merge and reassigning (self.layers = new_module)
542+
# would replace a child node in the NNX graph, which is detected as a graph
543+
# structure mutation when the parent module is inside a JAX transformation
544+
# (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity.
545+
nnx.update(layers, scanned_state)
546+
return final_carry, layers
533547

534548
def get_decoder_layers(self):
535549
"""Retrieves decoder layer classes based on config using a dictionary lookup."""

0 commit comments

Comments
 (0)