@@ -747,18 +747,24 @@ def _get_weight_bias(
747747 self ,
748748 module : torch .nn .Module ,
749749 dtype : torch .dtype = torch .float16 ,
750- name_to_value : dict [str , torch .Tensor ] = {} ,
750+ name_to_value : dict [str , torch .Tensor ] | None = None ,
751751 ) -> dict [str , torch .Tensor ]:
752752 """Get the weight and bias of the module.
753753
754754 Args:
755755 module: The target module to get the weight and bias.
756756 dtype: The data type of the weight and bias.
757- name_to_value: The dictionary to store the weight and bias.
757+ name_to_value: The dictionary to store the weight and bias. A new dict is created
758+ if not provided.
758759
759760 Returns:
760761 The dictionary containing the weight and bias.
761762 """
763+ if name_to_value is None :
764+ name_to_value = {}
765+ # numel() > 0 intentionally excludes zero-element weight tensors (e.g. MoE routing
766+ # layers whose weight is a placeholder) so callers can use "weight" in name_to_value
767+ # as a reliable guard without re-inspecting module.weight.
762768 if hasattr (module , "weight" ) and module .weight is not None and module .weight .numel () > 0 :
763769 weight = module .weight .to (dtype ).cpu ()
764770 name_to_value ["weight" ] = weight
@@ -801,9 +807,7 @@ def _get_quantized_state(
801807
802808 name_to_value = self ._get_weight_bias (module , dtype , name_to_value )
803809
804- if not (
805- hasattr (module , "weight" ) and module .weight is not None and module .weight .numel () > 0
806- ):
810+ if "weight" not in name_to_value :
807811 return name_to_value , qformat , block_size
808812
809813 if qformat == QUANTIZATION_NONE :
0 commit comments