Skip to content

Commit 741c59f

Browse files
fix(qwix): Dynamically reconstruct unknown variables in ToLinen wrapper
1 parent 7c1bf78 commit 741c59f

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

src/maxtext/layers/nnx_wrappers.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,10 +496,19 @@ def maybe_unbox(x):
496496

497497
if unknown_state_flat:
498498
paths_str = ""
499-
for path, _ in unknown_state_flat.items():
499+
for path, value_state in unknown_state_flat.items():
500500
paths_str += f"\n - {'/'.join(map(str, path))}"
501-
502-
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
501+
502+
# Dynamically reconstruct the unknown variables
503+
curr = module
504+
for p in path[:-1]:
505+
if not hasattr(curr, p):
506+
setattr(curr, p, nnx.Module())
507+
curr = getattr(curr, p)
508+
if not hasattr(curr, path[-1]):
509+
setattr(curr, path[-1], value_state.type(value_state.value))
510+
511+
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}. They have been dynamically reconstructed.")
503512

504513
nnx.update(module, new_state)
505514
_refresh_variable_trace_state(module)

0 commit comments

Comments
 (0)