|
24 | 24 | from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput |
25 | 25 |
|
26 | 26 | import modelopt.torch.quantization as mtq |
| 27 | +from modelopt.recipe import ModelOptPTQRecipe, load_recipe |
27 | 28 |
|
28 | 29 |
|
29 | 30 | def _create_new_data_cls(data_cls, **kwargs): |
@@ -141,22 +142,31 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list: |
141 | 142 | def get_quant_config(quant_config: dict[str, Any], model: Any) -> dict[str, Any]: |
142 | 143 | import copy |
143 | 144 |
|
144 | | - quant_cfg = ( |
145 | | - copy.deepcopy(getattr(mtq, quant_config["quant_cfg"])) if quant_config["quant_cfg"] else {} |
146 | | - ) |
147 | | - quant_kv_cfg = ( |
148 | | - copy.deepcopy(getattr(mtq, quant_config["kv_quant_cfg"])) |
149 | | - if quant_config["kv_quant_cfg"] |
150 | | - else {} |
151 | | - ) |
| 145 | + if quant_config["recipe_path"]: |
| 146 | + recipe = load_recipe(quant_config["recipe_path"]) |
| 147 | + assert isinstance(recipe, ModelOptPTQRecipe), ( |
| 148 | + f"Expected PTQ recipe, but got {type(recipe).__name__} from {quant_config['recipe_path']}" |
| 149 | + ) |
| 150 | + quant_cfg = recipe.quantize |
| 151 | + else: |
| 152 | + quant_cfg = ( |
| 153 | + copy.deepcopy(getattr(mtq, quant_config["quant_cfg"])) |
| 154 | + if quant_config["quant_cfg"] |
| 155 | + else {} |
| 156 | + ) |
| 157 | + quant_kv_cfg = ( |
| 158 | + copy.deepcopy(getattr(mtq, quant_config["kv_quant_cfg"])) |
| 159 | + if quant_config["kv_quant_cfg"] |
| 160 | + else {} |
| 161 | + ) |
152 | 162 |
|
153 | | - # Check if model has MLA and update KV config accordingly |
154 | | - if quant_kv_cfg: |
155 | | - quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) |
| 163 | + # Check if model has MLA and update KV config accordingly |
| 164 | + if quant_kv_cfg: |
| 165 | + quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) |
156 | 166 |
|
157 | | - if quant_kv_cfg: |
158 | | - quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( |
159 | | - quant_cfg, quant_kv_cfg["quant_cfg"] |
160 | | - ) |
| 167 | + if quant_kv_cfg: |
| 168 | + quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( |
| 169 | + quant_cfg, quant_kv_cfg["quant_cfg"] |
| 170 | + ) |
161 | 171 |
|
162 | 172 | return quant_cfg |
0 commit comments