Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

def _fakequant_run_prolog_worker(self) -> None:
trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"

tokenizer = AutoTokenizer.from_pretrained(
self.model_runner.model_config.tokenizer, trust_remote_code=trust_remote_code
)
Expand Down
80 changes: 65 additions & 15 deletions examples/vllm_serve/vllm_reload_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,32 @@
convert_to_quantized_model,
restore_quantizer_state,
)
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.utils import is_quantized


def _union_quantizer_keys_across_ranks(local_quantizer_keys: list[str]) -> set[str]:
"""Union of quantizer key strings from every rank (same file on all ranks → identical to local)."""
local = set(local_quantizer_keys)
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
return local
if torch.distributed.get_world_size() <= 1:
return local
try:
world_size = torch.distributed.get_world_size()
gathered: list[list[str]] = [[] for _ in range(world_size)]
torch.distributed.all_gather_object(gathered, list(local_quantizer_keys))
out: set[str] = set()
for g in gathered:
out.update(g)
return out
except Exception as e:
warnings.warn(
f"Could not all_gather quantizer key lists across ranks ({e}); using this rank's keys only."
)
return local


def _values_equal(v1: Any, v2: Any) -> bool:
"""Compare values, handling dicts with tensors."""
if isinstance(v1, dict) and isinstance(v2, dict):
Expand Down Expand Up @@ -285,7 +308,7 @@ def filter_modelopt_state_quantizer_state_for_model(
model: Model with quantizers (must already be converted)
"""
from modelopt.torch.quantization.conversion import quantizer_state
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.nn import TensorQuantizer
from modelopt.torch.utils import get_unwrapped_name

model_qstate = quantizer_state(model)
Expand Down Expand Up @@ -435,24 +458,51 @@ def load_state_dict_from_path(
# Count quant keys in checkpoint and model
checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key]
model_quant_keys = [key for key in current_state_dict if "quantizer" in key]
for key in checkpoint_quant_keys:
if key not in model_quant_keys:
print(f"Key {key} not found in model state dict, but exists in checkpoint")
ckpt_key_set = set(checkpoint_quant_keys)
global_ckpt_key_set = _union_quantizer_keys_across_ranks(checkpoint_quant_keys)
# For weight quantizers absent from the checkpoint the weights were already fake-quantized
# at export time (amax folded into weights). Disable those quantizers so that fold_weight
# is a no-op for them. Non-weight keys missing on this rank but present on another rank's
# shard are omitted from global_missing (all_gather union of key strings).
missing_wq_module_paths: set[str] = set()
global_missing_non_wq: list[str] = []
for key in model_quant_keys:
if key not in checkpoint_quant_keys:
raise ValueError(f"Key {key} not found in checkpoint state dict, but exists in model")

checkpoint_quant_count = len(checkpoint_quant_keys)
model_quant_count = len(model_quant_keys)

# Ensure counts match
if checkpoint_quant_count != model_quant_count:
if key in ckpt_key_set:
continue
if "weight_quantizer" in key:
# Per-rank shard: only disable using this rank's checkpoint contents.
parts = key.split(".")
weight_quantizer_index = next(
(i for i, p in enumerate(parts) if p.endswith("weight_quantizer")),
None,
)
if weight_quantizer_index is not None:
missing_wq_module_paths.add(".".join(parts[: weight_quantizer_index + 1]))
else:
raise ValueError(
f"Missing checkpoint key {key!r} looks like a weight quantizer, but no path "
"component ends with 'weight_quantizer'; cannot map to a module to disable."
)
elif key not in global_ckpt_key_set:
global_missing_non_wq.append(key)

if global_missing_non_wq:
keys = sorted(global_missing_non_wq)
n = len(keys)
sample, rest = keys[:8], n - 8
warnings.warn(
f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} "
f"quant keys but model has {model_quant_count} quantizer state keys. "
f"This can happen if the model is using PP."
f"{n} quantizer key(s) missing from every rank's checkpoint (after all_gather):"
f"{sample}{' ... (+{rest} more)' if rest > 0 else ''}"
)

for name, module in model.named_modules():
if (
name in missing_wq_module_paths
and isinstance(module, TensorQuantizer)
and hasattr(module, "disable")
):
module.disable()
Comment thread
kinjalpatel27 marked this conversation as resolved.

# Update quant values
saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict)
for key, value in saved_quant_dict.items():
Expand Down
42 changes: 40 additions & 2 deletions modelopt/torch/export/plugins/vllm_fakequant_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ def _get_quantized_state(
) -> tuple[dict[str, torch.Tensor], str, int]:
"""Return a state_dict, quantization format, and block_size of the module.

The weight_quantizer is folded into the weight via fake-quantization
(quantize + dequantize), and its amax is not exported. The vLLM fakequant
reload path is expected to disable the weight quantizer when the amax is absent.

Args:
module: The target module to perform real quantization.
dtype: The default data type.
Expand All @@ -133,14 +137,48 @@ def _get_quantized_state(
block_size = 0

if hasattr(module, "weight") and module.weight is not None:
weight = module.weight.to(dtype).cpu()
name_to_value["weight"] = weight
weight = module.weight.to(dtype)
# Fold the weight_quantizer into the weight by applying fake-quantization
# (quantize then dequantize). The weight_quantizer amax is not exported;
Comment thread
kinjalpatel27 marked this conversation as resolved.
# the vLLM fakequant reload path disables the weight quantizer when absent.
weight_quantizer = getattr(module, "weight_quantizer", None)
if weight_quantizer is not None and weight_quantizer.is_enabled:
with torch.no_grad():
# NVFP4-like kernels may need CUDA; if weights are CPU after gather, run on
# CUDA then ``weight_quantizer.to`` back (full module round-trip).
quant_device = (
torch.device("cuda", torch.cuda.current_device())
if weight.device.type == "cpu" and torch.cuda.is_available()
else weight.device
)
# TensorQuantizer does not expose nn.Module.device (custom __getattr__).
param_device = next(weight_quantizer.parameters(), None)
buf_device = next(weight_quantizer.buffers(), None)
wq_dev = (
param_device.device
if param_device is not None
else (buf_device.device if buf_device is not None else torch.device("cpu"))
)
need_move = wq_dev != quant_device
if need_move:
weight_quantizer.to(quant_device)
try:
weight = weight_quantizer(weight.to(quant_device)).to(dtype)
finally:
if need_move:
weight_quantizer.to(wq_dev)
name_to_value["weight"] = weight.cpu()
else:
return name_to_value, qformat, block_size

if hasattr(module, "bias") and module.bias is not None:
name_to_value["bias"] = module.bias.to(dtype).cpu()

# Only save input/output quantizer state; weight_quantizer amax is not exported
# since it has been folded into the weight above.
for name, param in get_quantizer_state_dict(module).items():
if "weight_quantizer" in name:
continue
for key, value in param.items():
name_to_value[name + "." + key] = value.to(dtype).cpu()
return name_to_value, qformat, block_size
Expand Down
Loading