Skip to content

Commit 2d43dc6

Browse files
author
Sharon Yu
committed
fix nnx_wrapper.py gpu UT failure
1 parent 8412184 commit 2d43dc6

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

src/maxtext/layers/nnx_wrappers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,9 +492,11 @@ def maybe_unbox(x):
492492
new_state = nnx.State(nnx.traversals.unflatten_mapping(filtered_state_flat))
493493

494494
# Use split and merge to create a new module bound to the current trace level
495-
# instead of using nnx.update which can leave stale tracers.
496-
_, graphdef = module.split()
497-
module = graphdef.merge(new_state)
495+
# instead of using nnx.update directly on the module which can leave stale tracers.
496+
# We must merge with the full state, so we split first, update the full state, and merge.
497+
graphdef, full_state = nnx.split(module)
498+
nnx.update(full_state, new_state)
499+
module = nnx.merge(graphdef, full_state)
498500

499501
_fix_for_qwix_quantization(module)
500502
method_fn = _get_module_method(module, nnx_method)

0 commit comments

Comments
 (0)