@@ -590,7 +590,6 @@ def apply_lora_to_model(
590590 model_rngs = getattr (model .decoder , "rngs" , None )
591591 decoder_input_tokens , decoder_positions = _prepare_dummy_inputs (mt_config )
592592
593- print (f"[DEBUG] Starting Qwix materialization on model type { type (model )} " )
594593 # Trigger materialization with Python loop fallback
595594 model .decoder .disable_quant_stats_update = True
596595 try :
@@ -601,66 +600,23 @@ def apply_lora_to_model(
601600 decoder_positions = decoder_positions ,
602601 rngs = model_rngs ,
603602 )
604- print (f"[DEBUG] Qwix call complete. Returned model type: { type (lora_model )} " )
605603 finally :
606604 model .decoder .disable_quant_stats_update = False
607605
608- # Important: Qwix dynamically swaps the __class__ of the model, which breaks nnx.iter_graph
609- # We must restore the original unquantized class type for Tunix to recognize the module correctly.
610- if hasattr (lora_model , "_unquantized_type" ):
611- lora_model .__class__ = getattr (lora_model , "_unquantized_type" )
612-
613606 model = lora_model
614607
615- def rank_consistent_spec (spec , shape ):
616- if spec is None : return None
617- spec_list = list (spec )
618- if len (shape ) < len (spec_list ):
619- for axis_name in ["layers" , "stage" ]:
620- while axis_name in spec_list and len (spec_list ) > len (shape ):
621- spec_list .remove (axis_name )
622- if len (spec_list ) > len (shape ):
623- spec_list = spec_list [- len (shape ):]
624- elif len (shape ) > len (spec_list ):
625- spec_list = [None ] * (len (shape ) - len (spec_list )) + spec_list
626- return jax .sharding .PartitionSpec (* spec_list )
627-
628608 if mesh is not None :
629609 with mesh , nn_partitioning .axis_rules (mt_config .logical_axis_rules ):
630610 graph_def , state = nnx .split (model )
631611
632- def fix_metadata (x ):
633- if hasattr (x , "get_metadata" ) and hasattr (x , "replace" ):
634- metadata = x .get_metadata ()
635- sharding_spec = metadata .get ("sharding" ) or metadata .get ("out_sharding" )
636- if sharding_spec :
637- new_spec = rank_consistent_spec (sharding_spec , x .value .shape )
638- x = x .replace (sharding = new_spec , out_sharding = new_spec )
639- try :
640- from maxtext .utils import sharding as mt_sharding
641- physical_sharding = mt_sharding .create_sharding (mesh , new_spec )
642- x .value = jax .device_put (x .value , physical_sharding )
643- except Exception : pass
644- return x
645-
646- state = jax .tree .map (fix_metadata , state )
647-
648- def force_sharding_on_device (x ):
649- if hasattr (x , "get_metadata" ) and hasattr (x , "value" ):
650- metadata = x .get_metadata ()
651- spec = metadata .get ("sharding" ) or metadata .get ("out_sharding" )
652- if spec :
653- try :
654- from maxtext .utils import sharding as mt_sharding
655- # Force rank-consistent physical sharding
656- physical_sharding = mt_sharding .create_sharding (mesh , spec )
657- x .value = jax .device_put (x .value , physical_sharding )
658- except Exception : pass
659- return x
660-
661- is_nnx_leaf = lambda x : hasattr (x , "get_metadata" ) and hasattr (x , "replace" )
662- state = jax .tree .map (force_sharding_on_device , state , is_leaf = is_nnx_leaf )
612+ default_memory_kind = jax .devices ()[0 ].default_memory ().kind
613+ dst_shardings = jax .tree .map (
614+ lambda x : jax .sharding .NamedSharding (mesh , x , memory_kind = default_memory_kind ) if x is not None else None ,
615+ nnx .get_partition_spec (state ),
616+ )
663617
618+ from tunix .rl import reshard # pylint: disable=import-outside-toplevel
619+ state = reshard .reshard_pytree (state , dst_shardings )
664620 model = nnx .merge (graph_def , state )
665621
666622 _verify_lora_parameters (model , mt_config )
0 commit comments