Skip to content

Commit c1b625c

Browse files
author
Charles Li
committed
Fix out_sharding issue
1 parent bde4876 commit c1b625c

2 files changed

Lines changed: 22 additions & 11 deletions

File tree

src/maxtext/layers/initializers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,13 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
9696
if out_sharding is not None:
9797
if nnx.PARTITION_NAME in metadata:
9898
partition_name = metadata[nnx.PARTITION_NAME]
99-
# Only nnx.Param variables are typically scanned across the param_scan_axis
10099
scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
101100

102-
if isinstance(out_sharding, str):
103-
out_sharding = [out_sharding]
104-
else:
105-
out_sharding = list(out_sharding)
101+
sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
102+
if partition_name not in sharding_list:
103+
sharding_list.insert(scan_axis, partition_name)
106104

107-
out_sharding.insert(scan_axis, partition_name)
108-
out_sharding = tuple(out_sharding)
105+
out_sharding = tuple(sharding_list)
109106

110107
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
111108
variable.value,

src/maxtext/layers/nnx_decoders.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171

7272
class NNXDecoderLayer(nnx.Module):
7373
"""
74-
Transformer decoder layer converted to NNX.
74+
Transformer decoder layer converted to NNX
7575
"""
7676

7777
def __init__(
@@ -451,13 +451,21 @@ def scan_body(carry, rng_state_slice):
451451

452452
# Add partition metadata that nnx.vmap's transform_metadata would normally set.
453453
# This metadata is read by variable_to_logically_partitioned() in initializers.py
454-
# to insert the scan axis name into logical sharding specs.
454+
# and by nnx.get_partition_spec() (via the updated out_sharding) to produce
455+
# correct sharding specs that include the scan axis dimension.
455456
def _add_scan_metadata(state, axis):
456457
def _update_leaf(leaf):
457458
if isinstance(leaf, nnx.VariableState):
458459
metadata = leaf.get_metadata()
459460
metadata[nnx.PARTITION_NAME] = metadata_axis_name
460461
metadata["param_scan_axis"] = axis
462+
# Insert the scan axis name into out_sharding so that
463+
# nnx.get_partition_spec returns specs matching the actual tensor rank.
464+
# Without this, scanned params are 3D but specs remain 2D.
465+
if "out_sharding" in metadata and metadata["out_sharding"]:
466+
out_sharding = list(metadata["out_sharding"])
467+
out_sharding.insert(axis, metadata_axis_name)
468+
metadata["out_sharding"] = tuple(out_sharding)
461469
return leaf.replace(**metadata)
462470
return leaf
463471

@@ -529,7 +537,13 @@ def layer_fn(carry, scanned_vars):
529537
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)
530538

531539
scanned_state = nnx.State.merge(params, scanned_other)
532-
return final_carry, nnx.merge(graphdef, scanned_state)
540+
# Update the existing module in-place rather than creating a new one.
541+
# Creating a new module via nnx.merge and reassigning (self.layers = new_module)
542+
# would replace a child node in the NNX graph, which is detected as a graph
543+
# structure mutation when the parent module is inside a JAX transformation
544+
# (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity.
545+
nnx.update(layers, scanned_state)
546+
return final_carry, layers
533547

534548
def get_decoder_layers(self):
535549
"""Retrieves decoder layer classes based on config using a dictionary lookup."""
@@ -1217,7 +1231,7 @@ def decoder_as_linen(
12171231
model_mode: str,
12181232
quant: None | Quant = None,
12191233
):
1220-
"""Creates a Decoder module."""
1234+
"""Creates a Decoder module"""
12211235
module = nnx_wrappers.to_linen(
12221236
NNXDecoder,
12231237
config=config,

0 commit comments

Comments
 (0)