@@ -964,22 +964,31 @@ def _revert_weight_conversion_noop(model: Any, state_dict: dict) -> dict:
964964 return state_dict
965965
966966
967- def _patch_revert_weight_conversion ( ) -> list [ tuple [Any , Any ]] :
968- """Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors ."""
967+ def _try_patch_module ( mod_path : str ) -> tuple [Any , Any ] | None :
968+ """Try to patch revert_weight_conversion in a single module ."""
969969 import importlib
970970
971+ try :
972+ mod = importlib .import_module (mod_path )
973+ if hasattr (mod , "revert_weight_conversion" ):
974+ original = getattr (mod , "revert_weight_conversion" )
975+ setattr (mod , "revert_weight_conversion" , _revert_weight_conversion_noop )
976+ return (mod , original )
977+ except (ImportError , AttributeError ):
978+ pass
979+ return None
980+
981+
982+ def _patch_revert_weight_conversion () -> list [tuple [Any , Any ]]:
983+ """Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors."""
971984 patches : list [tuple [Any , Any ]] = []
972985 for mod_path in [
973986 "transformers.core_model_loading" ,
974987 "transformers.modeling_utils" ,
975988 ]:
976- try :
977- mod = importlib .import_module (mod_path )
978- if hasattr (mod , "revert_weight_conversion" ):
979- patches .append ((mod , getattr (mod , "revert_weight_conversion" )))
980- setattr (mod , "revert_weight_conversion" , _revert_weight_conversion_noop )
981- except (ImportError , AttributeError ):
982- pass
989+ result = _try_patch_module (mod_path )
990+ if result is not None :
991+ patches .append (result )
983992 return patches
984993
985994
0 commit comments