@@ -959,6 +959,36 @@ def _export_diffusers_checkpoint(
959959 print (f"Export complete. Saved to: { export_dir } " )
960960
961961
962+ def _revert_weight_conversion_noop (model : Any , state_dict : dict ) -> dict :
963+ """No-op replacement for transformers' revert_weight_conversion."""
964+ return state_dict
965+
966+
def _patch_revert_weight_conversion() -> list[tuple[Any, Any]]:
    """Swap transformers' ``revert_weight_conversion`` for a no-op.

    The replacement avoids an IndexError on scalar (0-dimensional) tensors.
    Both ``transformers.core_model_loading`` and ``transformers.modeling_utils``
    are patched, because the latter may hold its own reference to the function
    obtained via a ``from ... import``.

    Patching is best-effort: a module that cannot be imported, lacks the
    attribute, or rejects the assignment is skipped silently.

    Returns:
        ``(module, original_function)`` pairs, one for each module that was
        actually patched. Pass the list to
        ``_unpatch_revert_weight_conversion`` to undo the change.

    Raises:
        Any exception other than ``ImportError``/``AttributeError`` raised
        while importing or patching. Modules already patched are restored
        before the exception propagates.
    """
    import importlib

    patches: list[tuple[Any, Any]] = []
    try:
        for mod_path in (
            "transformers.core_model_loading",
            "transformers.modeling_utils",
        ):
            try:
                mod = importlib.import_module(mod_path)
                if not hasattr(mod, "revert_weight_conversion"):
                    continue
                original = getattr(mod, "revert_weight_conversion")
                setattr(mod, "revert_weight_conversion", _revert_weight_conversion_noop)
            except (ImportError, AttributeError):
                continue
            # Record the patch only once the assignment has succeeded, so the
            # returned list never names a module that was left untouched.
            patches.append((mod, original))
    except BaseException:
        # This helper may be called before the caller's try/finally is
        # entered, so an unexpected failure here would otherwise leave the
        # earlier modules patched permanently. Roll back, then re-raise.
        for mod, original in reversed(patches):
            mod.revert_weight_conversion = original
        raise
    return patches
984+
985+
986+ def _unpatch_revert_weight_conversion (patches : list [tuple [Any , Any ]]) -> None :
987+ """Restore the original revert_weight_conversion functions."""
988+ for mod , original in patches :
989+ mod .revert_weight_conversion = original
990+
991+
962992def export_hf_checkpoint (
963993 model : Any ,
964994 dtype : torch .dtype | None = None ,
@@ -1022,21 +1052,7 @@ def export_hf_checkpoint(
10221052 # quantized state dicts (scalar scale tensors have 0 dimensions, causing IndexError).
10231053 # We must patch both the source module and the importing module since
10241054 # modeling_utils does `from core_model_loading import revert_weight_conversion`.
1025- _patches = []
1026- _noop = lambda model , state_dict : state_dict
1027- for _mod_path in [
1028- "transformers.core_model_loading" ,
1029- "transformers.modeling_utils" ,
1030- ]:
1031- try :
1032- import importlib
1033-
1034- _mod = importlib .import_module (_mod_path )
1035- if hasattr (_mod , "revert_weight_conversion" ):
1036- _patches .append ((_mod , getattr (_mod , "revert_weight_conversion" )))
1037- setattr (_mod , "revert_weight_conversion" , _noop )
1038- except (ImportError , AttributeError ):
1039- pass
1055+ _patches = _patch_revert_weight_conversion ()
10401056
10411057 try :
10421058 model .save_pretrained (
@@ -1045,8 +1061,7 @@ def export_hf_checkpoint(
10451061 save_modelopt_state = save_modelopt_state ,
10461062 )
10471063 finally :
1048- for _mod , _original in _patches :
1049- _mod .revert_weight_conversion = _original
1064+ _unpatch_revert_weight_conversion (_patches )
10501065
10511066 original_config = f"{ export_dir } /config.json"
10521067 config_data = {}
0 commit comments