Skip to content

Commit 065cfca

Browse files
committed
fixed bug
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 2fef374 commit 065cfca

2 files changed

Lines changed: 39 additions & 19 deletions

File tree

modelopt/torch/export/plugins/vllm_fakequant_megatron.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def _get_quantized_state(
135135
# string then it usually ends with "." which needs to be removed.
136136
self.exclude_modules.append(prefix.removesuffix("."))
137137
block_size = 0
138-
139-
if hasattr(module, "weight") and module.weight is not None:
140-
weight = module.weight.to(dtype)
138+
name_to_value = self._get_weight_bias(module, dtype, name_to_value)
139+
if "weight" in name_to_value:
140+
weight = name_to_value["weight"]
141141
# Fold the weight_quantizer into the weight by applying fake-quantization
142142
# (quantize then dequantize). The weight_quantizer amax is not exported;
143143
# the vLLM fakequant reload path disables the weight quantizer when absent.
@@ -171,9 +171,6 @@ def _get_quantized_state(
171171
else:
172172
return name_to_value, qformat, block_size
173173

174-
if hasattr(module, "bias") and module.bias is not None:
175-
name_to_value["bias"] = module.bias.to(dtype).cpu()
176-
177174
# Only save input/output quantizer state; weight_quantizer amax is not exported
178175
# since it has been folded into the weight above.
179176
for name, param in get_quantizer_state_dict(module).items():

modelopt/torch/export/unified_export_megatron.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,38 @@ def _custom_mapping_to_lambda(mapping):
743743

744744
return all_rules
745745

746+
def _get_weight_bias(
747+
self,
748+
module: torch.nn.Module,
749+
dtype: torch.dtype = torch.float16,
750+
name_to_value: dict[str, torch.Tensor] = {},
751+
) -> dict[str, torch.Tensor]:
752+
"""Get the weight and bias of the module.
753+
754+
Args:
755+
module: The target module to get the weight and bias.
756+
dtype: The data type of the weight and bias.
757+
name_to_value: The dictionary to store the weight and bias.
758+
759+
Returns:
760+
The dictionary containing the weight and bias.
761+
"""
762+
if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0:
763+
weight = module.weight.to(dtype).cpu()
764+
name_to_value["weight"] = weight
765+
766+
if hasattr(module, "bias") and module.bias is not None and module.bias.numel() > 0:
767+
name_to_value["bias"] = module.bias.to(dtype).cpu()
768+
769+
if (
770+
hasattr(module, "expert_bias")
771+
and module.expert_bias is not None
772+
and module.expert_bias.numel() > 0
773+
):
774+
name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu()
775+
776+
return name_to_value
777+
746778
def _get_quantized_state(
747779
self,
748780
module: torch.nn.Module,
@@ -767,21 +799,12 @@ def _get_quantized_state(
767799
self.exclude_modules.append(prefix.removesuffix("."))
768800
block_size = get_weight_block_size(module)
769801

770-
if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0:
771-
weight = module.weight.to(dtype).cpu()
772-
name_to_value["weight"] = weight
773-
else:
774-
return name_to_value, qformat, block_size
775-
776-
if hasattr(module, "bias") and module.bias is not None and module.bias.numel() > 0:
777-
name_to_value["bias"] = module.bias.to(dtype).cpu()
802+
name_to_value = self._get_weight_bias(module, dtype, name_to_value)
778803

779-
if (
780-
hasattr(module, "expert_bias")
781-
and module.expert_bias is not None
782-
and module.expert_bias.numel() > 0
804+
if not (
805+
hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0
783806
):
784-
name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu()
807+
return name_to_value, qformat, block_size
785808

786809
if qformat == QUANTIZATION_NONE:
787810
return name_to_value, qformat, block_size

0 commit comments

Comments (0)