Skip to content

Commit 670e85c

Browse files
committed
fix: remove __class__ swap in lora_utils to fix training loop crash, delete redundant sharding code
1 parent e96d2bb commit 670e85c

1 file changed

Lines changed: 7 additions & 51 deletions

File tree

src/maxtext/utils/lora_utils.py

Lines changed: 7 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -590,7 +590,6 @@ def apply_lora_to_model(
590590
model_rngs = getattr(model.decoder, "rngs", None)
591591
decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config)
592592

593-
print(f"[DEBUG] Starting Qwix materialization on model type {type(model)}")
594593
# Trigger materialization with Python loop fallback
595594
model.decoder.disable_quant_stats_update = True
596595
try:
@@ -601,66 +600,23 @@ def apply_lora_to_model(
601600
decoder_positions=decoder_positions,
602601
rngs=model_rngs,
603602
)
604-
print(f"[DEBUG] Qwix call complete. Returned model type: {type(lora_model)}")
605603
finally:
606604
model.decoder.disable_quant_stats_update = False
607605

608-
# Important: Qwix dynamically swaps the __class__ of the model, which breaks nnx.iter_graph
609-
# We must restore the original unquantized class type for Tunix to recognize the module correctly.
610-
if hasattr(lora_model, "_unquantized_type"):
611-
lora_model.__class__ = getattr(lora_model, "_unquantized_type")
612-
613606
model = lora_model
614607

615-
def rank_consistent_spec(spec, shape):
616-
if spec is None: return None
617-
spec_list = list(spec)
618-
if len(shape) < len(spec_list):
619-
for axis_name in ["layers", "stage"]:
620-
while axis_name in spec_list and len(spec_list) > len(shape):
621-
spec_list.remove(axis_name)
622-
if len(spec_list) > len(shape):
623-
spec_list = spec_list[-len(shape):]
624-
elif len(shape) > len(spec_list):
625-
spec_list = [None] * (len(shape) - len(spec_list)) + spec_list
626-
return jax.sharding.PartitionSpec(*spec_list)
627-
628608
if mesh is not None:
629609
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
630610
graph_def, state = nnx.split(model)
631611

632-
def fix_metadata(x):
633-
if hasattr(x, "get_metadata") and hasattr(x, "replace"):
634-
metadata = x.get_metadata()
635-
sharding_spec = metadata.get("sharding") or metadata.get("out_sharding")
636-
if sharding_spec:
637-
new_spec = rank_consistent_spec(sharding_spec, x.value.shape)
638-
x = x.replace(sharding=new_spec, out_sharding=new_spec)
639-
try:
640-
from maxtext.utils import sharding as mt_sharding
641-
physical_sharding = mt_sharding.create_sharding(mesh, new_spec)
642-
x.value = jax.device_put(x.value, physical_sharding)
643-
except Exception: pass
644-
return x
645-
646-
state = jax.tree.map(fix_metadata, state)
647-
648-
def force_sharding_on_device(x):
649-
if hasattr(x, "get_metadata") and hasattr(x, "value"):
650-
metadata = x.get_metadata()
651-
spec = metadata.get("sharding") or metadata.get("out_sharding")
652-
if spec:
653-
try:
654-
from maxtext.utils import sharding as mt_sharding
655-
# Force rank-consistent physical sharding
656-
physical_sharding = mt_sharding.create_sharding(mesh, spec)
657-
x.value = jax.device_put(x.value, physical_sharding)
658-
except Exception: pass
659-
return x
660-
661-
is_nnx_leaf = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
662-
state = jax.tree.map(force_sharding_on_device, state, is_leaf=is_nnx_leaf)
612+
default_memory_kind = jax.devices()[0].default_memory().kind
613+
dst_shardings = jax.tree.map(
614+
lambda x: jax.sharding.NamedSharding(mesh, x, memory_kind=default_memory_kind) if x is not None else None,
615+
nnx.get_partition_spec(state),
616+
)
663617

618+
from tunix.rl import reshard # pylint: disable=import-outside-toplevel
619+
state = reshard.reshard_pytree(state, dst_shardings)
664620
model = nnx.merge(graph_def, state)
665621

666622
_verify_lora_parameters(model, mt_config)

0 commit comments

Comments
 (0)