
Commit fa9b770

minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent: 9b42a09

1 file changed: 5 additions & 0 deletions

File: modelopt/torch/export/plugins/vllm_fakequant_megatron.py
@@ -99,6 +99,7 @@ def _get_quantized_state(
         self,
         module: torch.nn.Module,
         dtype: torch.dtype = torch.float16,
+        prefix: str = "",
     ) -> tuple[dict[str, torch.Tensor], str, int]:
         """Return a state_dict, quantization format, and block_size of the module.
@@ -111,6 +112,10 @@ def _get_quantized_state(
         """
         name_to_value = {}
         qformat: str = self._get_quantization_format(module)
+        if qformat is None and "norm" not in prefix:
+            # Add exclude layers for vllm fakequant config. Note that if the prefix is not an empty
+            # string then it usually ends with "." which needs to be removed.
+            self.exclude_modules.append(prefix.removesuffix("."))
         block_size = 0

         if hasattr(module, "weight") and module.weight is not None:
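For context, a minimal sketch of the bookkeeping this change introduces. The standalone class, the method name record_exclusion, and the getattr-based _get_quantization_format are illustrative assumptions, not the real ModelOpt exporter internals; only the prefix parameter, the exclude_modules list, and the removesuffix(".") handling come from the diff itself.

    # Minimal sketch, assuming a simplified standalone class.
    import torch


    class FakeQuantExporterSketch:
        def __init__(self) -> None:
            # Modules to be listed as excluded in the vLLM fakequant config.
            self.exclude_modules: list[str] = []

        def _get_quantization_format(self, module: torch.nn.Module) -> str | None:
            # Assumption: returns None when the module carries no quantizer.
            return getattr(module, "qformat", None)

        def record_exclusion(self, module: torch.nn.Module, prefix: str = "") -> None:
            qformat = self._get_quantization_format(module)
            if qformat is None and "norm" not in prefix:
                # A non-empty prefix usually ends with ".", so strip it before
                # recording the fully qualified module name.
                self.exclude_modules.append(prefix.removesuffix("."))


    exporter = FakeQuantExporterSketch()
    exporter.record_exclusion(torch.nn.Linear(4, 4), prefix="model.layers.0.mlp.gate.")
    print(exporter.exclude_modules)  # ['model.layers.0.mlp.gate']

Note that prefixes containing "norm" (e.g. a hypothetical model.layers.0.input_layernorm.) are never recorded, so only non-norm unquantized layers end up in the exclude list.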

0 commit comments

Comments
 (0)