Skip to content

Commit 2c50141

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 482f492 commit 2c50141

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

modelopt/torch/export/plugins/vllm_fakequant_megatron.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ def _get_quantized_state(
137137
block_size = 0
138138
name_to_value = self._get_weight_bias(module, dtype, name_to_value)
139139
if "weight" in name_to_value:
140-
weight = name_to_value["weight"]
140+
# Use the original device (avoid the CPU round-trip introduced by _get_weight_bias;
141+
# fake-quantization runs on CUDA and the result is moved to CPU below).
142+
weight = module.weight.to(dtype)
141143
# Fold the weight_quantizer into the weight by applying fake-quantization
142144
# (quantize then dequantize). The weight_quantizer amax is not exported;
143145
# the vLLM fakequant reload path disables the weight quantizer when absent.

modelopt/torch/export/unified_export_megatron.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -747,18 +747,24 @@ def _get_weight_bias(
747747
self,
748748
module: torch.nn.Module,
749749
dtype: torch.dtype = torch.float16,
750-
name_to_value: dict[str, torch.Tensor] = {},
750+
name_to_value: dict[str, torch.Tensor] | None = None,
751751
) -> dict[str, torch.Tensor]:
752752
"""Get the weight and bias of the module.
753753
754754
Args:
755755
module: The target module to get the weight and bias.
756756
dtype: The data type of the weight and bias.
757-
name_to_value: The dictionary to store the weight and bias.
757+
name_to_value: The dictionary to store the weight and bias. A new dict is created
758+
if not provided.
758759
759760
Returns:
760761
The dictionary containing the weight and bias.
761762
"""
763+
if name_to_value is None:
764+
name_to_value = {}
765+
# numel() > 0 intentionally excludes zero-element weight tensors (e.g. MoE routing
766+
# layers whose weight is a placeholder) so callers can use "weight" in name_to_value
767+
# as a reliable guard without re-inspecting module.weight.
762768
if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0:
763769
weight = module.weight.to(dtype).cpu()
764770
name_to_value["weight"] = weight
@@ -801,9 +807,7 @@ def _get_quantized_state(
801807

802808
name_to_value = self._get_weight_bias(module, dtype, name_to_value)
803809

804-
if not (
805-
hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0
806-
):
810+
if "weight" not in name_to_value:
807811
return name_to_value, qformat, block_size
808812

809813
if qformat == QUANTIZATION_NONE:

0 commit comments

Comments (0)