Skip to content

Commit 0d33159

Browse files
authored
Merge branch 'main' into shengliangx/normalize-yaml-ext
2 parents dcc0787 + 73be810 commit 0d33159

File tree

4 files changed: +30 additions, -16 deletions

examples/vllm_serve/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`,
2828
| QUANT_FILE_PATH | Optional path to exported quantizer state dict `quantizer_state.pth` | None |
2929
| MODELOPT_STATE_PATH | Optional path to exported `vllm_fq_modelopt_state.pth` (restores quantizer state and parameters) | None |
3030
| CALIB_BATCH_SIZE | Calibration batch size | 1 |
31+
| RECIPE_PATH | Optional path to a ModelOpt PTQ recipe YAML | None |
3132

3233
Set these variables in your shell or Docker environment as needed to customize calibration.
3334

@@ -65,7 +66,7 @@ Step 1: export the model with bf16 weights and quantizer state. To export the mo
6566
```bash
6667
python ../llm_ptq/hf_ptq.py \
6768
--pyt_ckpt_path <MODEL_PATH> \
68-
--qformat nvfp4 \
69+
--recipe <PATH_TO_RECIPE> \
6970
--calib_size 512 \
7071
--export_path <EXPORT_DIR> \
7172
--vllm_fakequant_export \

examples/vllm_serve/fakequant_worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
"quant_file_path": os.environ.get("QUANT_FILE_PATH", None),
4444
"modelopt_state_path": os.environ.get("MODELOPT_STATE_PATH", None),
4545
"calib_batch_size": int(os.environ.get("CALIB_BATCH_SIZE", 1)),
46+
"recipe_path": os.environ.get("RECIPE_PATH", None),
4647
}
4748

4849

@@ -138,6 +139,7 @@ def compile_or_warm_up_model(self) -> None:
138139
quant_config["quant_cfg"]
139140
or quant_config["kv_quant_cfg"]
140141
or quant_config["modelopt_state_path"]
142+
or quant_config["recipe_path"]
141143
):
142144
_fakequant_run_prolog_worker(self)
143145
super().compile_or_warm_up_model()

examples/vllm_serve/vllm_ptq_utils.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
2525

2626
import modelopt.torch.quantization as mtq
27+
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
2728

2829

2930
def _create_new_data_cls(data_cls, **kwargs):
@@ -141,22 +142,31 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list:
141142
def get_quant_config(quant_config: dict[str, Any], model: Any) -> dict[str, Any]:
142143
import copy
143144

144-
quant_cfg = (
145-
copy.deepcopy(getattr(mtq, quant_config["quant_cfg"])) if quant_config["quant_cfg"] else {}
146-
)
147-
quant_kv_cfg = (
148-
copy.deepcopy(getattr(mtq, quant_config["kv_quant_cfg"]))
149-
if quant_config["kv_quant_cfg"]
150-
else {}
151-
)
145+
if quant_config["recipe_path"]:
146+
recipe = load_recipe(quant_config["recipe_path"])
147+
assert isinstance(recipe, ModelOptPTQRecipe), (
148+
f"Expected PTQ recipe, but got {type(recipe).__name__} from {quant_config['recipe_path']}"
149+
)
150+
quant_cfg = recipe.quantize
151+
else:
152+
quant_cfg = (
153+
copy.deepcopy(getattr(mtq, quant_config["quant_cfg"]))
154+
if quant_config["quant_cfg"]
155+
else {}
156+
)
157+
quant_kv_cfg = (
158+
copy.deepcopy(getattr(mtq, quant_config["kv_quant_cfg"]))
159+
if quant_config["kv_quant_cfg"]
160+
else {}
161+
)
152162

153-
# Check if model has MLA and update KV config accordingly
154-
if quant_kv_cfg:
155-
quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"])
163+
# Check if model has MLA and update KV config accordingly
164+
if quant_kv_cfg:
165+
quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"])
156166

157-
if quant_kv_cfg:
158-
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
159-
quant_cfg, quant_kv_cfg["quant_cfg"]
160-
)
167+
if quant_kv_cfg:
168+
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
169+
quant_cfg, quant_kv_cfg["quant_cfg"]
170+
)
161171

162172
return quant_cfg

examples/vllm_serve/vllm_serve_fakequant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"KV_QUANT_CFG",
7979
"MODELOPT_STATE_PATH",
8080
"CALIB_BATCH_SIZE",
81+
"RECIPE_PATH",
8182
"TRUST_REMOTE_CODE",
8283
}
8384

0 commit comments

Comments (0)