1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """ Common LoRA utils needed to support LoRA adapters."""
15+ """Common LoRA utils needed to support LoRA adapters."""
1616from functools import partial
1717import json
1818import os
@@ -385,14 +385,20 @@ def _get_lora_module_path(mt_config: pyconfig.HyperParameters) -> str:
385385 model_name = mt_config .model_name .lower ()
386386
387387 # Find the first matching architecture prefix or use 'default'
388- matched_key = next ((k for k in lora_configs if k != "default" and model_name .startswith (k )), "default" )
388+ matched_key = next (
389+ (k for k in lora_configs if k != "default" and model_name .startswith (k )),
390+ "default" ,
391+ )
389392
390393 if matched_key == "default" :
391394 max_logging .log (f"Warning: Model '{ model_name } ' is unverified; falling back to default LoRA path." )
392395 else :
393396 max_logging .log (f"Auto-detected lora_module_path for model '{ model_name } ' (matched: '{ matched_key } ')" )
394397
395- raw_path = lora_configs .get (matched_key , "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" )
398+ raw_path = lora_configs .get (
399+ matched_key ,
400+ "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" ,
401+ )
396402
397403 # This regex makes the layer index optional, matching both scanned and unscanned layer paths
398404 # (e.g. 'layers/0/mlp/...' vs 'layers/mlp/...').
@@ -412,10 +418,15 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid
412418 "alpha" : mt_config .lora .lora_alpha ,
413419 "dropout" : 0.0 ,
414420 }
415- max_logging .log (
416- f"LoRA configured: module_path={ lora_module_path } "
417- f"rank={ mt_config .lora .lora_rank } alpha={ mt_config .lora .lora_alpha } "
418- )
421+ if mt_config .lora .lora_tile_size is not None :
422+ lora_kwargs ["tile_size" ] = mt_config .lora .lora_tile_size
423+ if mt_config .lora .lora_weight_qtype is not None :
424+ lora_kwargs ["weight_qtype" ] = mt_config .lora .lora_weight_qtype
425+
426+ lora_type = "QLoRA" if mt_config .lora .lora_weight_qtype else "LoRA"
427+ args_str = " " .join (f"{ k } ={ v } " for k , v in lora_kwargs .items () if k != "dropout" )
428+ max_logging .log (f"{ lora_type } configured: { args_str } " )
429+
419430 return qwix .LoraProvider (** lora_kwargs )
420431
421432
@@ -448,7 +459,7 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
448459 matched_module_paths = []
449460 sample_module_paths = []
450461
451- for path , _ in nnx .iter_modules (lora_model ):
462+ for path , _ in nnx .iter_graph (lora_model ):
452463 module_path = "/" .join (str (p ) for p in path )
453464 if len (sample_module_paths ) < 100 :
454465 sample_module_paths .append (module_path )
@@ -469,6 +480,81 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
469480 )
470481
471482
def _patch_qwix_for_maxtext(mesh, mt_config):
  """Monkey-patches qwix internals; safe to call more than once.

  Two patches are installed, each guarded by its own sentinel attribute so that
  repeated calls do not stack wrappers on top of each other:

  1. `PtqProvider.get_intercept_map` is wrapped so that the returned mapping
     overrides the "jax.numpy.asarray" entry with a version that unwraps
     `nnx.State` containers and passes `WithAux` / `QArray` objects through
     unchanged instead of converting them.
  2. `flax_util.find_param` is replaced with a version that, for `nnx.Module`
     instances, resolves the parameter name by object identity, then by walking
     single-input tracer parents, and finally by a unique shape match.

  Args:
    mesh: Unused; kept so the call signature stays stable for callers.
    mt_config: Unused; kept so the call signature stays stable for callers.
  """
  # pylint: disable=protected-access,import-outside-toplevel,redefined-outer-name,reimported,missing-function-docstring,consider-using-from-import
  import qwix._src.flax_util as flax_util
  import qwix._src.providers.ptq as ptq
  import jax.numpy as jnp
  from flax import nnx

  del mesh, mt_config  # Unused.

  # 1. PTQ patch. Guarded like the find_param patch below: without the guard,
  # every call would wrap the previously installed wrapper one level deeper.
  if not hasattr(ptq, "_maxtext_intercept_map_patched"):
    original_get_intercept_map = ptq.PtqProvider.get_intercept_map

    def patched_get_intercept_map(self):
      mapping = original_get_intercept_map(self)

      def intercept_asarray(a, dtype=None, order=None, **kwargs):
        # Unwrap a State that carries its payload under the "array" key.
        if isinstance(a, nnx.State) and "array" in a:
          a = a["array"]
        # Rebuild a QArray from a State holding separate qvalue/scale leaves.
        if isinstance(a, nnx.State) and "qvalue" in a and "scale" in a:
          a = ptq.QArray(qvalue=a["qvalue"].value, scale=a["scale"].value)

        # Matched by class name rather than isinstance, as in the original
        # patch; these objects are returned as-is, bypassing jnp.asarray.
        if type(a).__name__ in ("WithAux", "QArray"):
          return a
        return jnp.asarray(a, dtype=dtype, order=order, **kwargs)

      mapping["jax.numpy.asarray"] = intercept_asarray
      return mapping

    ptq.PtqProvider.get_intercept_map = patched_get_intercept_map
    ptq._maxtext_intercept_map_patched = True

  # 2. find_param patch
  if not hasattr(flax_util, "_maxtext_find_param_patched"):
    # Capture the library implementation before it is replaced. Looking it up
    # through `flax_util.find_param` at call time would return the patched
    # function itself, so the non-nnx fallback could never reach the original.
    original_find_param = flax_util.find_param

    def patched_find_param(x, ptq_array_type=None):
      module = flax_util.get_current_module()
      if module is None:
        return None
      if not isinstance(module, nnx.Module):
        # Anything that is not an nnx module keeps the stock qwix behavior.
        return original_find_param(x, ptq_array_type)

      array_types = (nnx.Param, ptq_array_type) if ptq_array_type else nnx.Param
      candidates = {
          name: node.value if isinstance(node, nnx.Param) else node
          for name, node in vars(module).items()
          if isinstance(node, array_types)
      }

      # Index candidates by identity; a WithAux candidate is also reachable
      # through the array it wraps.
      candidates_by_id = {id(c): n for n, c in candidates.items()}
      for n, c in candidates.items():
        if type(c).__name__ == "WithAux" and hasattr(c, "array"):
          candidates_by_id[id(c.array)] = n

      if id(x) in candidates_by_id:
        return candidates_by_id[id(x)]

      if isinstance(x, jax.core.Tracer) and hasattr(x, "parent"):
        # Follow chains of single-input operations back towards a candidate.
        curr_x = x
        while True:
          if id(curr_x) in candidates_by_id:
            return candidates_by_id[id(curr_x)]
          # An input tracer is not guaranteed to expose `parent`; treat a
          # missing attribute like an absent parent instead of raising.
          parent = getattr(curr_x, "parent", None)
          if parent and len(parent.in_tracers) == 1:
            curr_x = parent.in_tracers[0]
          elif hasattr(curr_x, "get_const") and id(const := curr_x.get_const()) in candidates_by_id:
            return candidates_by_id[id(const)]
          else:
            break

      # Last resort: accept a shape match only when it is unambiguous.
      x_shape = getattr(x, "shape", None)
      same_shape = [n for n, c in candidates.items() if hasattr(c, "shape") and c.shape == x_shape]
      if len(same_shape) == 1:
        return same_shape[0]
      return None

    flax_util.find_param = patched_find_param
    flax_util._maxtext_find_param_patched = True
557+
472558def apply_lora_to_model (
473559 model : nnx .Module ,
474560 mesh : Optional [jax .sharding .Mesh ],
@@ -485,6 +571,8 @@ def apply_lora_to_model(
485571
486572 # Dynamically detect and set LoRA rank before model creation if restoring
487573
574+ _patch_qwix_for_maxtext (mesh , mt_config )
575+
488576 lora_provider = _build_lora_provider (mt_config )
489577
490578 model_rngs = getattr (model .decoder , "rngs" , None )
0 commit comments