Skip to content

Commit a2739ab

Browse files
committed
fix: remove __class__ swap in lora_utils to fix training loop crash, delete redundant sharding code
1 parent 76b64e0 commit a2739ab

1 file changed

Lines changed: 7 additions & 51 deletions

File tree

src/maxtext/utils/lora_utils.py

Lines changed: 7 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,6 @@ def apply_lora_to_model(
628628
model_rngs = getattr(model.decoder, "rngs", None)
629629
decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config)
630630

631-
print(f"[DEBUG] Starting Qwix materialization on model type {type(model)}")
632631
# Trigger materialization with Python loop fallback
633632
model.decoder.disable_quant_stats_update = True
634633
try:
@@ -639,66 +638,23 @@ def apply_lora_to_model(
639638
decoder_positions=decoder_positions,
640639
rngs=model_rngs,
641640
)
642-
print(f"[DEBUG] Qwix call complete. Returned model type: {type(lora_model)}")
643641
finally:
644642
model.decoder.disable_quant_stats_update = False
645643

646-
# Important: Qwix dynamically swaps the __class__ of the model, which breaks nnx.iter_graph
647-
# We must restore the original unquantized class type for Tunix to recognize the module correctly.
648-
if hasattr(lora_model, "_unquantized_type"):
649-
lora_model.__class__ = getattr(lora_model, "_unquantized_type")
650-
651644
model = lora_model
652645

653-
def rank_consistent_spec(spec, shape):
654-
if spec is None: return None
655-
spec_list = list(spec)
656-
if len(shape) < len(spec_list):
657-
for axis_name in ["layers", "stage"]:
658-
while axis_name in spec_list and len(spec_list) > len(shape):
659-
spec_list.remove(axis_name)
660-
if len(spec_list) > len(shape):
661-
spec_list = spec_list[-len(shape):]
662-
elif len(shape) > len(spec_list):
663-
spec_list = [None] * (len(shape) - len(spec_list)) + spec_list
664-
return jax.sharding.PartitionSpec(*spec_list)
665-
666646
if mesh is not None:
667647
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
668648
graph_def, state = nnx.split(model)
669649

670-
def fix_metadata(x):
671-
if hasattr(x, "get_metadata") and hasattr(x, "replace"):
672-
metadata = x.get_metadata()
673-
sharding_spec = metadata.get("sharding") or metadata.get("out_sharding")
674-
if sharding_spec:
675-
new_spec = rank_consistent_spec(sharding_spec, x.value.shape)
676-
x = x.replace(sharding=new_spec, out_sharding=new_spec)
677-
try:
678-
from maxtext.utils import sharding as mt_sharding
679-
physical_sharding = mt_sharding.create_sharding(mesh, new_spec)
680-
x.value = jax.device_put(x.value, physical_sharding)
681-
except Exception: pass
682-
return x
683-
684-
state = jax.tree.map(fix_metadata, state)
685-
686-
def force_sharding_on_device(x):
687-
if hasattr(x, "get_metadata") and hasattr(x, "value"):
688-
metadata = x.get_metadata()
689-
spec = metadata.get("sharding") or metadata.get("out_sharding")
690-
if spec:
691-
try:
692-
from maxtext.utils import sharding as mt_sharding
693-
# Force rank-consistent physical sharding
694-
physical_sharding = mt_sharding.create_sharding(mesh, spec)
695-
x.value = jax.device_put(x.value, physical_sharding)
696-
except Exception: pass
697-
return x
698-
699-
is_nnx_leaf = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
700-
state = jax.tree.map(force_sharding_on_device, state, is_leaf=is_nnx_leaf)
650+
default_memory_kind = jax.devices()[0].default_memory().kind
651+
dst_shardings = jax.tree.map(
652+
lambda x: jax.sharding.NamedSharding(mesh, x, memory_kind=default_memory_kind) if x is not None else None,
653+
nnx.get_partition_spec(state),
654+
)
701655

656+
from tunix.rl import reshard # pylint: disable=import-outside-toplevel
657+
state = reshard.reshard_pytree(state, dst_shardings)
702658
model = nnx.merge(graph_def, state)
703659

704660
_verify_lora_parameters(model, mt_config)

0 commit comments

Comments
 (0)