Skip to content

Commit 672c9d0

Browse files
committed
fix: remove __class__ swap in lora_utils to fix training loop crash, delete redundant sharding code
1 parent f5736a1 commit 672c9d0

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

src/maxtext/utils/lora_utils.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -506,30 +506,36 @@ def apply_lora_to_model(
506506
model_rngs = getattr(model.decoder, "rngs", None)
507507
decoder_input_tokens, decoder_positions = _prepare_dummy_inputs()
508508

509-
lora_model = qwix.apply_lora_to_model(
510-
model,
511-
lora_provider,
512-
decoder_input_tokens=decoder_input_tokens,
513-
decoder_positions=decoder_positions,
514-
rngs=model_rngs,
515-
)
509+
# Trigger materialization with Python loop fallback
510+
model.decoder.disable_quant_stats_update = True
511+
try:
512+
lora_model = qwix.apply_lora_to_model(
513+
model,
514+
lora_provider,
515+
decoder_input_tokens=decoder_input_tokens,
516+
decoder_positions=decoder_positions,
517+
rngs=model_rngs,
518+
)
519+
finally:
520+
model.decoder.disable_quant_stats_update = False
521+
522+
model = lora_model
516523

517524
if mesh is not None:
518525
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
519-
graph_def, state = nnx.split(lora_model)
526+
graph_def, state = nnx.split(model)
520527
default_memory_kind = jax.devices()[0].default_memory().kind
521528
dst_shardings = jax.tree.map(
522529
lambda x: jax.sharding.NamedSharding(mesh, x, memory_kind=default_memory_kind) if x is not None else None,
523530
nnx.get_partition_spec(state),
524531
)
525532
from tunix.rl import reshard # pylint: disable=import-outside-toplevel
526-
527533
state = reshard.reshard_pytree(state, dst_shardings)
528-
lora_model = nnx.merge(graph_def, state)
534+
model = nnx.merge(graph_def, state)
529535

530-
_verify_lora_parameters(lora_model, mt_config)
536+
_verify_lora_parameters(model, mt_config)
531537

532-
return lora_model
538+
return model
533539

534540

535541
def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:

0 commit comments

Comments (0)