@@ -99,6 +99,21 @@ def _maybe_print_module_tree(model) -> None:
9999 print_module_tree (model = model )
100100
101101
102+ def _convert_model_with_defuser (cls , model , cleanup_original : bool ) -> bool :
103+ converted = defuser .convert_model (model , cleanup_original = cleanup_original )
104+
105+ defuser_module_paths = getattr (cls , "defuser_module_paths" , ())
106+ if defuser_module_paths :
107+ for module_path in defuser_module_paths :
108+ module , _ = get_module_by_name_prefix (model , module_path )
109+ if module is None :
110+ log .warn ("Loader: defuser module path `%s` was not found." , module_path )
111+ continue
112+ converted = defuser .convert_model (module , cleanup_original = cleanup_original ) or converted
113+
114+ return converted
115+
116+
102117def _supports_flash_attn_2 (config : PretrainedConfig ) -> bool :
103118 """Detect whether the resolved HF architecture exposes FA2 kernels."""
104119
@@ -727,12 +742,12 @@ def skip(*args, **kwargs):
727742 )
728743 if getattr (model , "config" , None ) is config :
729744 model .config = copy .deepcopy (config )
730- defuser . convert_model ( model , cleanup_original = False )
745+ _convert_model_with_defuser ( cls , model , cleanup_original = False )
731746 model ._model_init_kwargs = fallback_init_kwargs
732747 _maybe_print_module_tree (model = model )
733748 turtle_model = None
734749 else :
735- defuser . convert_model ( model , cleanup_original = False )
750+ _convert_model_with_defuser ( cls , model , cleanup_original = False )
736751 shell_model_init_kwargs = dict (model_init_kwargs_without_internal )
737752 shell_model_init_kwargs .update (hf_gguf_load_kwargs )
738753 model ._model_init_kwargs = shell_model_init_kwargs
@@ -768,7 +783,7 @@ def skip(*args, **kwargs):
768783 )
769784 if getattr (model , "config" , None ) is config :
770785 model .config = copy .deepcopy (config )
771- defuser . convert_model ( model , cleanup_original = False )
786+ _convert_model_with_defuser ( cls , model , cleanup_original = False )
772787 direct_model_init_kwargs = dict (model_init_kwargs_without_internal )
773788 direct_model_init_kwargs .update (hf_gguf_load_kwargs )
774789 model ._model_init_kwargs = direct_model_init_kwargs
@@ -1188,7 +1203,7 @@ def from_quantized(
11881203 )
11891204 else :
11901205 raise
1191- defuser . convert_model ( model , cleanup_original = True )
1206+ _convert_model_with_defuser ( cls , model , cleanup_original = True )
11921207 model .checkpoint_file_name = model_save_name
11931208 if native_gguf_qspec is not None :
11941209 gguf_tensor_key_mapping = _build_gguf_tensor_key_mapping (model , config )
0 commit comments