@@ -628,7 +628,6 @@ def apply_lora_to_model(
628628 model_rngs = getattr (model .decoder , "rngs" , None )
629629 decoder_input_tokens , decoder_positions = _prepare_dummy_inputs (mt_config )
630630
631- print (f"[DEBUG] Starting Qwix materialization on model type { type (model )} " )
632631 # Trigger materialization with Python loop fallback
633632 model .decoder .disable_quant_stats_update = True
634633 try :
@@ -639,66 +638,23 @@ def apply_lora_to_model(
639638 decoder_positions = decoder_positions ,
640639 rngs = model_rngs ,
641640 )
642- print (f"[DEBUG] Qwix call complete. Returned model type: { type (lora_model )} " )
643641 finally :
644642 model .decoder .disable_quant_stats_update = False
645643
646- # Important: Qwix dynamically swaps the __class__ of the model, which breaks nnx.iter_graph
647- # We must restore the original unquantized class type for Tunix to recognize the module correctly.
648- if hasattr (lora_model , "_unquantized_type" ):
649- lora_model .__class__ = getattr (lora_model , "_unquantized_type" )
650-
651644 model = lora_model
652645
653- def rank_consistent_spec (spec , shape ):
654- if spec is None : return None
655- spec_list = list (spec )
656- if len (shape ) < len (spec_list ):
657- for axis_name in ["layers" , "stage" ]:
658- while axis_name in spec_list and len (spec_list ) > len (shape ):
659- spec_list .remove (axis_name )
660- if len (spec_list ) > len (shape ):
661- spec_list = spec_list [- len (shape ):]
662- elif len (shape ) > len (spec_list ):
663- spec_list = [None ] * (len (shape ) - len (spec_list )) + spec_list
664- return jax .sharding .PartitionSpec (* spec_list )
665-
666646 if mesh is not None :
667647 with mesh , nn_partitioning .axis_rules (mt_config .logical_axis_rules ):
668648 graph_def , state = nnx .split (model )
669649
670- def fix_metadata (x ):
671- if hasattr (x , "get_metadata" ) and hasattr (x , "replace" ):
672- metadata = x .get_metadata ()
673- sharding_spec = metadata .get ("sharding" ) or metadata .get ("out_sharding" )
674- if sharding_spec :
675- new_spec = rank_consistent_spec (sharding_spec , x .value .shape )
676- x = x .replace (sharding = new_spec , out_sharding = new_spec )
677- try :
678- from maxtext .utils import sharding as mt_sharding
679- physical_sharding = mt_sharding .create_sharding (mesh , new_spec )
680- x .value = jax .device_put (x .value , physical_sharding )
681- except Exception : pass
682- return x
683-
684- state = jax .tree .map (fix_metadata , state )
685-
686- def force_sharding_on_device (x ):
687- if hasattr (x , "get_metadata" ) and hasattr (x , "value" ):
688- metadata = x .get_metadata ()
689- spec = metadata .get ("sharding" ) or metadata .get ("out_sharding" )
690- if spec :
691- try :
692- from maxtext .utils import sharding as mt_sharding
693- # Force rank-consistent physical sharding
694- physical_sharding = mt_sharding .create_sharding (mesh , spec )
695- x .value = jax .device_put (x .value , physical_sharding )
696- except Exception : pass
697- return x
698-
699- is_nnx_leaf = lambda x : hasattr (x , "get_metadata" ) and hasattr (x , "replace" )
700- state = jax .tree .map (force_sharding_on_device , state , is_leaf = is_nnx_leaf )
650+ default_memory_kind = jax .devices ()[0 ].default_memory ().kind
651+ dst_shardings = jax .tree .map (
652+ lambda x : jax .sharding .NamedSharding (mesh , x , memory_kind = default_memory_kind ) if x is not None else None ,
653+ nnx .get_partition_spec (state ),
654+ )
701655
656+ from tunix .rl import reshard # pylint: disable=import-outside-toplevel
657+ state = reshard .reshard_pytree (state , dst_shardings )
702658 model = nnx .merge (graph_def , state )
703659
704660 _verify_lora_parameters (model , mt_config )
0 commit comments