|
71 | 71 |
|
72 | 72 | class NNXDecoderLayer(nnx.Module): |
73 | 73 | """ |
74 | | - Transformer decoder layer converted to NNX. |
| 74 | + Transformer decoder layer converted to NNX |
75 | 75 | """ |
76 | 76 |
|
77 | 77 | def __init__( |
@@ -451,13 +451,21 @@ def scan_body(carry, rng_state_slice): |
451 | 451 |
|
452 | 452 | # Add partition metadata that nnx.vmap's transform_metadata would normally set. |
453 | 453 | # This metadata is read by variable_to_logically_partitioned() in initializers.py |
454 | | - # to insert the scan axis name into logical sharding specs. |
| 454 | + # and by nnx.get_partition_spec() (via the updated out_sharding) to produce |
| 455 | + # correct sharding specs that include the scan axis dimension. |
455 | 456 | def _add_scan_metadata(state, axis): |
456 | 457 | def _update_leaf(leaf): |
457 | 458 | if isinstance(leaf, nnx.VariableState): |
458 | 459 | metadata = leaf.get_metadata() |
459 | 460 | metadata[nnx.PARTITION_NAME] = metadata_axis_name |
460 | 461 | metadata["param_scan_axis"] = axis |
| 462 | + # Insert the scan axis name into out_sharding so that |
| 463 | + # nnx.get_partition_spec returns specs matching the actual tensor rank. |
| 464 | + # Without this, scanned params are 3D but specs remain 2D. |
| 465 | + if "out_sharding" in metadata and metadata["out_sharding"]: |
| 466 | + out_sharding = list(metadata["out_sharding"]) |
| 467 | + out_sharding.insert(axis, metadata_axis_name) |
| 468 | + metadata["out_sharding"] = tuple(out_sharding) |
461 | 469 | return leaf.replace(**metadata) |
462 | 470 | return leaf |
463 | 471 |
|
@@ -529,7 +537,13 @@ def layer_fn(carry, scanned_vars): |
529 | 537 | params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) |
530 | 538 |
|
531 | 539 | scanned_state = nnx.State.merge(params, scanned_other) |
532 | | - return final_carry, nnx.merge(graphdef, scanned_state) |
| 540 | + # Update the existing module in-place rather than creating a new one. |
| 541 | + # Creating a new module via nnx.merge and reassigning (self.layers = new_module) |
| 542 | + # would replace a child node in the NNX graph, which is detected as a graph |
| 543 | + # structure mutation when the parent module is inside a JAX transformation |
| 544 | + # (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity. |
| 545 | + nnx.update(layers, scanned_state) |
| 546 | + return final_carry, layers |
533 | 547 |
|
534 | 548 | def get_decoder_layers(self): |
535 | 549 | """Retrieves decoder layer classes based on config using a dictionary lookup.""" |
@@ -1217,7 +1231,7 @@ def decoder_as_linen( |
1217 | 1231 | model_mode: str, |
1218 | 1232 | quant: None | Quant = None, |
1219 | 1233 | ): |
1220 | | - """Creates a Decoder module.""" |
| 1234 | + """Creates a Decoder module""" |
1221 | 1235 | module = nnx_wrappers.to_linen( |
1222 | 1236 | NNXDecoder, |
1223 | 1237 | config=config, |
|
0 commit comments