diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index c56088fd7f..b88af9c72e 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -49,6 +49,7 @@ def _fakequant_run_prolog_worker(self) -> None: trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true" + tokenizer = AutoTokenizer.from_pretrained( self.model_runner.model_config.tokenizer, trust_remote_code=trust_remote_code ) diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py index b67c92ae6e..2b59d1be2b 100644 --- a/examples/vllm_serve/vllm_reload_utils.py +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -31,9 +31,32 @@ convert_to_quantized_model, restore_quantizer_state, ) +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.utils import is_quantized +def _union_quantizer_keys_across_ranks(local_quantizer_keys: list[str]) -> set[str]: + """Union of quantizer key strings from every rank (same file on all ranks → identical to local).""" + local = set(local_quantizer_keys) + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return local + if torch.distributed.get_world_size() <= 1: + return local + try: + world_size = torch.distributed.get_world_size() + gathered: list[list[str]] = [[] for _ in range(world_size)] + torch.distributed.all_gather_object(gathered, list(local_quantizer_keys)) + out: set[str] = set() + for g in gathered: + out.update(g) + return out + except Exception as e: + warnings.warn( + f"Could not all_gather quantizer key lists across ranks ({e}); using this rank's keys only." 
+ ) + return local + + def _values_equal(v1: Any, v2: Any) -> bool: """Compare values, handling dicts with tensors.""" if isinstance(v1, dict) and isinstance(v2, dict): @@ -285,7 +308,7 @@ def filter_modelopt_state_quantizer_state_for_model( model: Model with quantizers (must already be converted) """ from modelopt.torch.quantization.conversion import quantizer_state - from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer + from modelopt.torch.quantization.nn import TensorQuantizer from modelopt.torch.utils import get_unwrapped_name model_qstate = quantizer_state(model) @@ -435,24 +458,51 @@ def load_state_dict_from_path( # Count quant keys in checkpoint and model checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key] model_quant_keys = [key for key in current_state_dict if "quantizer" in key] - for key in checkpoint_quant_keys: - if key not in model_quant_keys: - print(f"Key {key} not found in model state dict, but exists in checkpoint") + ckpt_key_set = set(checkpoint_quant_keys) + global_ckpt_key_set = _union_quantizer_keys_across_ranks(checkpoint_quant_keys) + # For weight quantizers absent from the checkpoint the weights were already fake-quantized + # at export time (amax folded into weights). Disable those quantizers so that fold_weight + # is a no-op for them. Non-weight keys missing on this rank but present on another rank's + # shard are omitted from global_missing (all_gather union of key strings). 
+ missing_wq_module_paths: set[str] = set() + global_missing_non_wq: list[str] = [] for key in model_quant_keys: - if key not in checkpoint_quant_keys: - raise ValueError(f"Key {key} not found in checkpoint state dict, but exists in model") - - checkpoint_quant_count = len(checkpoint_quant_keys) - model_quant_count = len(model_quant_keys) - - # Ensure counts match - if checkpoint_quant_count != model_quant_count: + if key in ckpt_key_set: + continue + if "weight_quantizer" in key: + # Per-rank shard: only disable using this rank's checkpoint contents. + parts = key.split(".") + weight_quantizer_index = next( + (i for i, p in enumerate(parts) if p.endswith("weight_quantizer")), + None, + ) + if weight_quantizer_index is not None: + missing_wq_module_paths.add(".".join(parts[: weight_quantizer_index + 1])) + else: + raise ValueError( + f"Missing checkpoint key {key!r} looks like a weight quantizer, but no path " + "component ends with 'weight_quantizer'; cannot map to a module to disable." + ) + elif key not in global_ckpt_key_set: + global_missing_non_wq.append(key) + + if global_missing_non_wq: + keys = sorted(global_missing_non_wq) + n = len(keys) + sample, rest = keys[:8], n - 8 warnings.warn( - f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} " - f"quant keys but model has {model_quant_count} quantizer state keys. " - f"This can happen if the model is using PP." + f"{n} quantizer key(s) missing from every rank's checkpoint (after all_gather):" + f"{sample}{f' ... 
(+{rest} more)' if rest > 0 else ''}" ) + for name, module in model.named_modules(): + if ( + name in missing_wq_module_paths + and isinstance(module, TensorQuantizer) + and hasattr(module, "disable") + ): + module.disable() + # Update quant values saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict) for key, value in saved_quant_dict.items(): diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 9a41ae2baf..c8e45be365 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -117,6 +117,10 @@ def _get_quantized_state( ) -> tuple[dict[str, torch.Tensor], str, int]: """Return a state_dict, quantization format, and block_size of the module. + The weight_quantizer is folded into the weight via fake-quantization + (quantize + dequantize), and its amax is not exported. The vLLM fakequant + reload path is expected to disable the weight quantizer when the amax is absent. + Args: module: The target module to perform real quantization. dtype: The default data type. @@ -133,14 +137,48 @@ def _get_quantized_state( block_size = 0 if hasattr(module, "weight") and module.weight is not None: - weight = module.weight.to(dtype).cpu() - name_to_value["weight"] = weight + weight = module.weight.to(dtype) + # Fold the weight_quantizer into the weight by applying fake-quantization + # (quantize then dequantize). The weight_quantizer amax is not exported; + # the vLLM fakequant reload path disables the weight quantizer when absent. + weight_quantizer = getattr(module, "weight_quantizer", None) + if weight_quantizer is not None and weight_quantizer.is_enabled: + with torch.no_grad(): + # NVFP4-like kernels may need CUDA; if weights are CPU after gather, run on + # CUDA then ``weight_quantizer.to`` back (full module round-trip). 
+ quant_device = ( + torch.device("cuda", torch.cuda.current_device()) + if weight.device.type == "cpu" and torch.cuda.is_available() + else weight.device + ) + # TensorQuantizer does not expose nn.Module.device (custom __getattr__). + param_device = next(weight_quantizer.parameters(), None) + buf_device = next(weight_quantizer.buffers(), None) + wq_dev = ( + param_device.device + if param_device is not None + else (buf_device.device if buf_device is not None else torch.device("cpu")) + ) + need_move = wq_dev != quant_device + if need_move: + weight_quantizer.to(quant_device) + try: + weight = weight_quantizer(weight.to(quant_device)).to(dtype) + finally: + if need_move: + weight_quantizer.to(wq_dev) + name_to_value["weight"] = weight.cpu() else: return name_to_value, qformat, block_size if hasattr(module, "bias") and module.bias is not None: name_to_value["bias"] = module.bias.to(dtype).cpu() + + # Only save input/output quantizer state; weight_quantizer amax is not exported + # since it has been folded into the weight above. for name, param in get_quantizer_state_dict(module).items(): + if "weight_quantizer" in name: + continue for key, value in param.items(): name_to_value[name + "." + key] = value.to(dtype).cpu() return name_to_value, qformat, block_size