@@ -506,30 +506,36 @@ def apply_lora_to_model(
506506 model_rngs = getattr (model .decoder , "rngs" , None )
507507 decoder_input_tokens , decoder_positions = _prepare_dummy_inputs ()
508508
509- lora_model = qwix .apply_lora_to_model (
510- model ,
511- lora_provider ,
512- decoder_input_tokens = decoder_input_tokens ,
513- decoder_positions = decoder_positions ,
514- rngs = model_rngs ,
515- )
509+ # Trigger materialization with Python loop fallback
510+ model .decoder .disable_quant_stats_update = True
511+ try :
512+ lora_model = qwix .apply_lora_to_model (
513+ model ,
514+ lora_provider ,
515+ decoder_input_tokens = decoder_input_tokens ,
516+ decoder_positions = decoder_positions ,
517+ rngs = model_rngs ,
518+ )
519+ finally :
520+ model .decoder .disable_quant_stats_update = False
521+
522+ model = lora_model
516523
517524 if mesh is not None :
518525 with mesh , nn_partitioning .axis_rules (mt_config .logical_axis_rules ):
519- graph_def , state = nnx .split (lora_model )
526+ graph_def , state = nnx .split (model )
520527 default_memory_kind = jax .devices ()[0 ].default_memory ().kind
521528 dst_shardings = jax .tree .map (
522529 lambda x : jax .sharding .NamedSharding (mesh , x , memory_kind = default_memory_kind ) if x is not None else None ,
523530 nnx .get_partition_spec (state ),
524531 )
525532 from tunix .rl import reshard # pylint: disable=import-outside-toplevel
526-
527533 state = reshard .reshard_pytree (state , dst_shardings )
528- lora_model = nnx .merge (graph_def , state )
534+ model = nnx .merge (graph_def , state )
529535
530- _verify_lora_parameters (lora_model , mt_config )
536+ _verify_lora_parameters (model , mt_config )
531537
532- return lora_model
538+ return model
533539
534540
535541def restore_lora_from_path (trainer : Any , mt_config : pyconfig .HyperParameters ) -> Any :
0 commit comments