Skip to content

Commit e96d2bb

Browse files
committed
chore: simplify LoRA logic and remove debug prints
1 parent 114b53a commit e96d2bb

2 files changed

Lines changed: 14 additions & 56 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -589,23 +589,8 @@ def _promote_to_scanned(x):
589589
curr_params = jax.tree.map(lambda x: _slice_and_unpromote(x, i), params, is_leaf=is_nnx_leaf)
590590
curr_state = jax.tree.map(lambda x: _slice_and_unpromote(x, i), state, is_leaf=is_nnx_leaf)
591591

592-
if i == 0:
593-
print(f"[DEBUG] Available nodes in Iter 0: {list(curr_params.keys())}")
594-
595-
print(f"[DEBUG] Iteration {i}: calling layer_fn")
596592
curr_carry, (out_p, out_o) = layer_fn(curr_carry, (curr_params, curr_state))
597593

598-
# Inspect results for new LoRA params
599-
if i == 0:
600-
def find_lora(p, path=""):
601-
if isinstance(p, nnx.State):
602-
for k, v in p.items(): find_lora(v, f"{path}/{k}")
603-
elif isinstance(p, dict):
604-
for k, v in p.items(): find_lora(v, f"{path}/{k}")
605-
elif hasattr(p, "__class__") and p.__class__.__name__ == "LoRAParam":
606-
print(f"[DEBUG] FOUND NEW LORA PARAM at {path}")
607-
find_lora(out_p, "params")
608-
609594
# Promote ALL parameters back to rank-3 metadata immediately
610595
# This ensures they are ready to be stacked correctly.
611596
out_p = jax.tree.map(_promote_to_scanned, out_p, is_leaf=is_nnx_leaf)
@@ -615,7 +600,6 @@ def find_lora(p, path=""):
615600
out_other_list.append(out_o)
616601

617602
final_carry = curr_carry
618-
print(f"[DEBUG] Loop complete, stacking results and promoting metadata...")
619603
scanned_params = jax.tree.map(lambda *args: jnp.stack(args), *out_params_list)
620604
scanned_other = jax.tree.map(lambda *args: jnp.stack(args), *out_other_list)
621605

@@ -627,11 +611,12 @@ def _force_promote(x):
627611
if is_nnx_leaf:
628612
metadata = x.get_metadata()
629613
updates = {}
614+
val_ndim = x.value.ndim
630615
for sharding_key in ["sharding", "out_sharding", "sharding_names"]:
631616
axes = metadata.get(sharding_key)
632617
if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
633618
l = list(axes)
634-
if "layers" not in l:
619+
if len(l) < val_ndim and "layers" not in l:
635620
pos = min(scan_axis, len(l))
636621
l.insert(pos, "layers")
637622
updates[sharding_key] = jax.sharding.PartitionSpec(*l) if isinstance(axes, jax.sharding.PartitionSpec) else tuple(l)

src/maxtext/utils/lora_utils.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -436,31 +436,6 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
436436

437437
matched_module_paths = []
438438
sample_module_paths = []
439-
found_lora = False
440-
seen = set()
441-
442-
# Truly recursive search to find LoRAParam regardless of NNX registration state
443-
def recursive_find_lora(obj):
444-
nonlocal found_lora
445-
if found_lora or id(obj) in seen: return
446-
seen.add(id(obj))
447-
448-
if hasattr(obj, "__class__") and obj.__class__.__name__ == "LoRAParam":
449-
found_lora = True
450-
return
451-
452-
if hasattr(obj, "__dict__"):
453-
for k, v in obj.__dict__.items():
454-
if not k.startswith("__"):
455-
recursive_find_lora(v)
456-
elif isinstance(obj, (dict, list, tuple)):
457-
items = obj.values() if isinstance(obj, dict) else obj
458-
for v in items: recursive_find_lora(v)
459-
460-
recursive_find_lora(lora_model)
461-
462-
if found_lora:
463-
return
464439

465440
for path, _ in nnx.iter_graph(lora_model):
466441
module_path = "/".join(str(p) for p in path)
@@ -578,6 +553,12 @@ def patched_get_or_create_lora_params(*, name, rule, a_shape, b_shape, a_shardin
578553
b_sharding_transpose=b_sharding_transpose,
579554
)
580555

556+
# Ensure they are specifically LoRAParam, not just generic Param or Variable
557+
if hasattr(lora_a, "value") and hasattr(lora_a, "get_metadata"):
558+
lora_a = nnx.LoRAParam(lora_a.value, **lora_a.get_metadata())
559+
if hasattr(lora_b, "value") and hasattr(lora_b, "get_metadata"):
560+
lora_b = nnx.LoRAParam(lora_b.value, **lora_b.get_metadata())
561+
581562
# Force registration on the current module
582563
module = flax_util.get_current_module()
583564
if isinstance(module, nnx.Module):
@@ -624,20 +605,12 @@ def apply_lora_to_model(
624605
finally:
625606
model.decoder.disable_quant_stats_update = False
626607

627-
# Important: use the NEW model returned by Qwix!
608+
# Important: Qwix dynamically swaps the __class__ of the model, which breaks nnx.iter_graph
609+
# We must restore the original unquantized class type for Tunix to recognize the module correctly.
610+
if hasattr(lora_model, "_unquantized_type"):
611+
lora_model.__class__ = getattr(lora_model, "_unquantized_type")
612+
628613
model = lora_model
629-
630-
# Check if we can find lora in this model immediately
631-
temp_found = []
632-
def quick_check(obj, path=""):
633-
if len(temp_found) > 0: return
634-
if hasattr(obj, "__class__") and obj.__class__.__name__ == "LoRAParam":
635-
temp_found.append(path)
636-
if hasattr(obj, "__dict__"):
637-
for k,v in obj.__dict__.items():
638-
if not k.startswith("__"): quick_check(v, f"{path}/{k}")
639-
quick_check(model, "root")
640-
print(f"[DEBUG] Quick check for LoRA in lora_model: {temp_found}")
641614

642615
def rank_consistent_spec(spec, shape):
643616
if spec is None: return None
@@ -654,7 +627,7 @@ def rank_consistent_spec(spec, shape):
654627

655628
if mesh is not None:
656629
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
657-
graph_def, state = nnx.split(lora_model)
630+
graph_def, state = nnx.split(model)
658631

659632
def fix_metadata(x):
660633
if hasattr(x, "get_metadata") and hasattr(x, "replace"):

0 commit comments

Comments
 (0)