Skip to content

Commit 28ea6c4

Browse files
author
Sharon Yu
committed
fix nnx_wrapper.py comment
1 parent 5c07324 commit 28ea6c4

1 file changed

Lines changed: 5 additions & 16 deletions

File tree

src/maxtext/layers/nnx_wrappers.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from flax.core import FrozenDict
2727
from flax.core import meta
2828
from flax.nnx import graph
29-
from flax.nnx import tracers as nnx_tracers
3029
from flax.nnx import variablelib
3130
from flax.nnx.bridge import module as bdg_module
3231
from flax.nnx.module import Module
@@ -180,19 +179,6 @@ def is_linen_initializing() -> bool:
180179
return False
181180

182181

183-
def _refresh_variable_trace_state(module: Module) -> None:
184-
"""Resets stale ``_trace_state`` on Variables to unblock downstream ``nnx.split``.
185-
186-
``nnx.update`` called with JAX tracer values uses ``_unsafe_bypass_check=True``,
187-
which leaves Variables with a stale ``_trace_state`` from the outer Python
188-
context and breaks ``nnx.split`` with "Cannot extract graph node from different
189-
trace level". Resets ``_trace_state`` on any Variable whose ``_can_update`` is False.
190-
"""
191-
for _, v in nnx.graph.iter_graph(module):
192-
if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access
193-
object.__setattr__(v, "_trace_state", nnx_tracers.TraceState())
194-
195-
196182
class ToNNX(Module):
197183
"""A wrapper to turn any Linen module into an NNX module.
198184
@@ -505,8 +491,11 @@ def maybe_unbox(x):
505491
filtered_state_flat = {k: v for k, v in new_state_flat.items() if k not in unknown_state_flat}
506492
new_state = nnx.State(nnx.traversals.unflatten_mapping(filtered_state_flat))
507493

508-
nnx.update(module, new_state)
509-
_refresh_variable_trace_state(module)
494+
# Use split and merge to create a new module bound to the current trace level
495+
# instead of using nnx.update, which can leave Variables with a stale trace state.
496+
_, graphdef = module.split()
497+
module = graphdef.merge(new_state)
498+
510499
_fix_for_qwix_quantization(module)
511500
method_fn = _get_module_method(module, nnx_method)
512501
out = method_fn(module, *args, **kwargs)

0 commit comments

Comments
 (0)