|
26 | 26 | from flax.core import FrozenDict |
27 | 27 | from flax.core import meta |
28 | 28 | from flax.nnx import graph |
29 | | -from flax.nnx import tracers as nnx_tracers |
30 | 29 | from flax.nnx import variablelib |
31 | 30 | from flax.nnx.bridge import module as bdg_module |
32 | 31 | from flax.nnx.module import Module |
@@ -180,19 +179,6 @@ def is_linen_initializing() -> bool: |
180 | 179 | return False |
181 | 180 |
|
182 | 181 |
|
183 | | -def _refresh_variable_trace_state(module: Module) -> None: |
184 | | - """Resets stale ``_trace_state`` on Variables to unblock downstream ``nnx.split``. |
185 | | -
|
186 | | - ``nnx.update`` called with JAX tracer values uses ``_unsafe_bypass_check=True``, |
187 | | - which leaves Variables with a stale ``_trace_state`` from the outer Python |
188 | | - context and breaks ``nnx.split`` with "Cannot extract graph node from different |
189 | | - trace level". Resets ``_trace_state`` on any Variable whose ``_can_update`` is False. |
190 | | - """ |
191 | | - for _, v in nnx.graph.iter_graph(module): |
192 | | - if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access |
193 | | - object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) |
194 | | - |
195 | | - |
196 | 182 | class ToNNX(Module): |
197 | 183 | """A wrapper to turn any Linen module into an NNX module. |
198 | 184 |
|
@@ -505,8 +491,11 @@ def maybe_unbox(x): |
505 | 491 | filtered_state_flat = {k: v for k, v in new_state_flat.items() if k not in unknown_state_flat} |
506 | 492 | new_state = nnx.State(nnx.traversals.unflatten_mapping(filtered_state_flat)) |
507 | 493 |
|
508 | | - nnx.update(module, new_state) |
509 | | - _refresh_variable_trace_state(module) |
| 494 | + # Use split and merge to create a new module bound to the current trace level |
| 495 | + # instead of using nnx.update which can leave stale tracers. |
| 496 | + _, graphdef = module.split() |
| 497 | + module = graphdef.merge(new_state) |
| 498 | + |
510 | 499 | _fix_for_qwix_quantization(module) |
511 | 500 | method_fn = _get_module_method(module, nnx_method) |
512 | 501 | out = method_fn(module, *args, **kwargs) |
|
0 commit comments