@@ -743,6 +743,38 @@ def _custom_mapping_to_lambda(mapping):
743743
744744 return all_rules
745745
def _get_weight_bias(
    self,
    module: torch.nn.Module,
    dtype: torch.dtype = torch.float16,
    name_to_value: "dict[str, torch.Tensor] | None" = None,
) -> "dict[str, torch.Tensor]":
    """Collect the weight and bias tensors of a module into a dictionary.

    The attributes ``weight``, ``bias`` and ``expert_bias`` are looked up on
    ``module``. Each one that exists, is not ``None`` and has at least one
    element is cast to ``dtype``, moved to the CPU and stored under its
    attribute name.

    Args:
        module: The target module to read the weight and bias from.
        dtype: The data type the collected tensors are cast to.
        name_to_value: Optional dictionary to store the tensors in. When
            given, it is updated in place and returned. When omitted, a new
            dictionary is created for this call.

    Returns:
        The dictionary containing the collected tensors.
    """
    # A literal ``{}`` default would be created once and shared by every call
    # that omits the argument, leaking entries from one module into the next.
    if name_to_value is None:
        name_to_value = {}

    for attr_name in ("weight", "bias", "expert_bias"):
        tensor = getattr(module, attr_name, None)
        if tensor is not None and tensor.numel() > 0:
            name_to_value[attr_name] = tensor.to(dtype).cpu()

    return name_to_value
777+
746778 def _get_quantized_state (
747779 self ,
748780 module : torch .nn .Module ,
@@ -767,21 +799,12 @@ def _get_quantized_state(
767799 self .exclude_modules .append (prefix .removesuffix ("." ))
768800 block_size = get_weight_block_size (module )
769801
770- if hasattr (module , "weight" ) and module .weight is not None and module .weight .numel () > 0 :
771- weight = module .weight .to (dtype ).cpu ()
772- name_to_value ["weight" ] = weight
773- else :
774- return name_to_value , qformat , block_size
775-
776- if hasattr (module , "bias" ) and module .bias is not None and module .bias .numel () > 0 :
777- name_to_value ["bias" ] = module .bias .to (dtype ).cpu ()
802+ name_to_value = self ._get_weight_bias (module , dtype , name_to_value )
778803
779- if (
780- hasattr (module , "expert_bias" )
781- and module .expert_bias is not None
782- and module .expert_bias .numel () > 0
804+ if not (
805+ hasattr (module , "weight" ) and module .weight is not None and module .weight .numel () > 0
783806 ):
784- name_to_value [ "expert_bias" ] = module . expert_bias . to ( dtype ). cpu ()
807+ return name_to_value , qformat , block_size
785808
786809 if qformat == QUANTIZATION_NONE :
787810 return name_to_value , qformat , block_size
0 commit comments