@@ -436,31 +436,6 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
436 436
437 437   matched_module_paths = []
438 438   sample_module_paths = []
439    -  found_lora = False
440    -  seen = set()
441    -
442    -  # Truly recursive search to find LoRAParam regardless of NNX registration state
443    -  def recursive_find_lora(obj):
444    -    nonlocal found_lora
445    -    if found_lora or id(obj) in seen: return
446    -    seen.add(id(obj))
447    -
448    -    if hasattr(obj, "__class__") and obj.__class__.__name__ == "LoRAParam":
449    -      found_lora = True
450    -      return
451    -
452    -    if hasattr(obj, "__dict__"):
453    -      for k, v in obj.__dict__.items():
454    -        if not k.startswith("__"):
455    -          recursive_find_lora(v)
456    -    elif isinstance(obj, (dict, list, tuple)):
457    -      items = obj.values() if isinstance(obj, dict) else obj
458    -      for v in items: recursive_find_lora(v)
459    -
460    -  recursive_find_lora(lora_model)
461    -
462    -  if found_lora:
463    -    return
464 439
465 440   for path, _ in nnx.iter_graph(lora_model):
466 441     module_path = "/".join(str(p) for p in path)
@@ -578,6 +553,12 @@ def patched_get_or_create_lora_params(*, name, rule, a_shape, b_shape, a_shardin
578 553       b_sharding_transpose=b_sharding_transpose,
579 554   )
580 555
    556+  # Ensure they are specifically LoRAParam, not just generic Param or Variable
    557+  if hasattr(lora_a, "value") and hasattr(lora_a, "get_metadata"):
    558+    lora_a = nnx.LoRAParam(lora_a.value, **lora_a.get_metadata())
    559+  if hasattr(lora_b, "value") and hasattr(lora_b, "get_metadata"):
    560+    lora_b = nnx.LoRAParam(lora_b.value, **lora_b.get_metadata())
    561+
581 562   # Force registration on the current module
582 563   module = flax_util.get_current_module()
583 564   if isinstance(module, nnx.Module):
@@ -624,20 +605,12 @@ def apply_lora_to_model(
624 605   finally:
625 606     model.decoder.disable_quant_stats_update = False
626 607
627    -  # Important: use the NEW model returned by Qwix!
    608+  # Important: Qwix dynamically swaps the __class__ of the model, which breaks nnx.iter_graph
    609+  # We must restore the original unquantized class type for Tunix to recognize the module correctly.
    610+  if hasattr(lora_model, "_unquantized_type"):
    611+    lora_model.__class__ = getattr(lora_model, "_unquantized_type")
    612+
628 613   model = lora_model
629    -
630    -  # Check if we can find lora in this model immediately
631    -  temp_found = []
632    -  def quick_check(obj, path=""):
633    -    if len(temp_found) > 0: return
634    -    if hasattr(obj, "__class__") and obj.__class__.__name__ == "LoRAParam":
635    -      temp_found.append(path)
636    -    if hasattr(obj, "__dict__"):
637    -      for k, v in obj.__dict__.items():
638    -        if not k.startswith("__"): quick_check(v, f"{path}/{k}")
639    -  quick_check(model, "root")
640    -  print(f"[DEBUG] Quick check for LoRA in lora_model: {temp_found}")
641 614
642 615   def rank_consistent_spec(spec, shape):
643 616     if spec is None: return None
@@ -654,7 +627,7 @@ def rank_consistent_spec(spec, shape):
654 627
655 628   if mesh is not None:
656 629     with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
657    -      graph_def, state = nnx.split(lora_model)
    630+      graph_def, state = nnx.split(model)
658 631
659 632       def fix_metadata(x):
660 633         if hasattr(x, "get_metadata") and hasattr(x, "replace"):
0 commit comments