@@ -451,13 +451,21 @@ def scan_body(carry, rng_state_slice):
451451
452452 # Add partition metadata that nnx.vmap's transform_metadata would normally set.
453453 # This metadata is read by variable_to_logically_partitioned() in initializers.py
454- # to insert the scan axis name into logical sharding specs.
454+ # and by nnx.get_partition_spec() (via the updated out_sharding) to produce
455+ # correct sharding specs that include the scan axis dimension.
455456 def _add_scan_metadata (state , axis ):
def _update_leaf(leaf):
    """Attach scan-axis partition metadata to a single state leaf.

    Leaves that are not ``nnx.VariableState`` are returned unchanged.
    For variable states, the partition name and ``param_scan_axis`` are
    recorded, and, when a non-empty ``out_sharding`` is present, the scan
    axis name is inserted into it so the sharding spec gains one entry for
    the stacked (scanned) dimension.

    ``axis`` and ``metadata_axis_name`` are taken from the enclosing scope.
    """
    if not isinstance(leaf, nnx.VariableState):
        return leaf

    metadata = leaf.get_metadata()
    metadata[nnx.PARTITION_NAME] = metadata_axis_name
    metadata["param_scan_axis"] = axis

    # Insert the scan axis name into out_sharding so that
    # nnx.get_partition_spec returns specs matching the actual tensor rank.
    # Without this, scanned params gain a dimension but specs do not.
    if "out_sharding" in metadata and metadata["out_sharding"]:
        names = list(metadata["out_sharding"])
        # list.insert() with a negative index inserts *before* that element
        # (insert(-1, x) yields [..., x, last]), whereas a negative axis is
        # meant to index the result that already contains the new dimension
        # (as jnp.moveaxis does). Normalize against the post-insert length
        # so axis=-1 places the name last.
        position = axis if axis >= 0 else axis + len(names) + 1
        names.insert(position, metadata_axis_name)
        metadata["out_sharding"] = tuple(names)

    return leaf.replace(**metadata)
463471
@@ -529,7 +537,13 @@ def _layer_fn(carry, scanned_vars):
529537 params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), params )
530538
531539 scanned_state = nnx .State .merge (params , scanned_other )
532- return final_carry , nnx .merge (graphdef , scanned_state )
540+ # Update the existing module in-place rather than creating a new one.
541+ # Creating a new module via nnx.merge and reassigning (self.layers = new_module)
542+ # would replace a child node in the NNX graph, which is detected as a graph
543+ # structure mutation when the parent module is inside a JAX transformation
544+ # (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity.
545+ nnx .update (layers , scanned_state )
546+ return final_carry , layers
533547
534548 def get_decoder_layers (self ):
535549 """Retrieves decoder layer classes based on config using a dictionary lookup."""
0 commit comments