
Commit d7e311f

update

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent: e8243e9

3 files changed: 10 additions & 16 deletions


examples/llm_ptq/example_utils.py

Lines changed: 2 additions & 2 deletions
@@ -583,7 +583,7 @@ def get_model(
     model_kwargs = config_kwargs.copy()
     # Don't set torch_dtype for VILA models as they handle it explicitly in their builder
     if "vila" not in ckpt_path.lower():
-        model_kwargs.setdefault("torch_dtype", "auto")
+        model_kwargs.setdefault("dtype", "auto")

     if "vila" in ckpt_path.lower():
         hf_vila = AutoModel.from_pretrained(
@@ -666,7 +666,7 @@ def has_pack_quantized_config(config):
     model_kwargs2 = model_kwargs.copy()
     if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
         model_kwargs2.pop("trust_remote_code", None)
-        model_kwargs2["torch_dtype"] = torch_dtype
+        model_kwargs2["dtype"] = torch_dtype
         model_kwargs2.pop("max_memory", None)
     model = from_config(hf_config, **model_kwargs2)
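
Context for the rename: recent transformers releases deprecate the
torch_dtype keyword of from_pretrained in favor of dtype, which is
presumably what this commit tracks. A minimal sketch of the updated
call, assuming such a release (the checkpoint name is an arbitrary
example):

# Sketch only: assumes a transformers version where dtype supersedes
# the deprecated torch_dtype keyword in from_pretrained.
from transformers import AutoModelForCausalLM

model_kwargs = {}
model_kwargs.setdefault("dtype", "auto")  # previously: "torch_dtype"
model = AutoModelForCausalLM.from_pretrained("gpt2", **model_kwargs)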

examples/llm_ptq/hf_ptq.py

Lines changed: 4 additions & 9 deletions
@@ -420,15 +420,10 @@ def load_model(args: argparse.Namespace):
             attn_implementation=args.attn_implementation,
         )
     else:
-        if args.qformat in QUANT_CFG_CHOICES:
-            quant_cfg = QUANT_CFG_CHOICES[args.qformat]
-        elif hasattr(mtq, args.qformat):
-            quant_cfg = getattr(mtq, args.qformat)
-        else:
-            raise AssertionError(
-                f"Quantization format is not supported for low memory mode. "
-                f"Supported formats: {QUANT_CFG_CHOICES.keys()}"
-            )
+        assert args.qformat in QUANT_CFG_CHOICES, (
+            f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
+        )
+        quant_cfg = QUANT_CFG_CHOICES[args.qformat]
     if args.kv_cache_qformat != "none":
         quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
             quant_cfg,
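
Beyond swapping raise for assert, this refactor drops the old
getattr(mtq, args.qformat) fallback, so low memory mode now accepts
only the formats listed in QUANT_CFG_CHOICES. A standalone sketch of
that narrowing, using hypothetical stand-ins for the mtq attribute
and the QUANT_CFG_CHOICES entries:

# Hypothetical stand-ins; only the acceptance logic mirrors the diff.
import types

mtq = types.SimpleNamespace(SOME_DEFAULT_CFG={"algo": "example"})
QUANT_CFG_CHOICES = {"fp8": {"algo": "fp8"}}

qformat = "SOME_DEFAULT_CFG"
old_accepts = qformat in QUANT_CFG_CHOICES or hasattr(mtq, qformat)  # True
new_accepts = qformat in QUANT_CFG_CHOICES  # False
print(old_accepts, new_accepts)

One side effect worth noting: unlike the old raise, an assert is
stripped when Python runs with the -O flag, so the guard only fires
in non-optimized runs.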

modelopt/torch/quantization/model_calib.py

Lines changed: 4 additions & 5 deletions
@@ -1665,7 +1665,10 @@ def gptq(

     def _make_gptq_handle(name, m):
         backend = getattr(m.weight_quantizer, "backend", None)
-        cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
+        if backend is None:
+            cls = GPTQHelper
+        else:
+            cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
         return cls(m, name, offload_to_cpu=True)

     gptq_handles = {name: _make_gptq_handle(name, m) for name, m in quantized_layers}
@@ -1685,10 +1688,6 @@ def _make_gptq_handle(name, m):
     print_rank_0("Updating weights using GPTQ algorithm...")
     for handle in gptq_handles.values():
         handle.update_weights(block_size, perc_damp)
-
-        # Disable weight quantizer after running GPTQ update since weights are already QDQ'ed
-        if hasattr(handle.module, "weight_quantizer"):
-            handle.module.weight_quantizer.disable()
         handle.free()
     del gptq_handles

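
The expanded branch is not a pure rewrite of dict.get: the get call
falls back to GPTQHelper only when the key is absent, so if None were
ever registered as a key in _GPTQ_HELPER_REGISTRY, a quantizer with
no backend would silently pick up that entry. The explicit check pins
backend=None to GPTQHelper regardless. A small illustration with a
hypothetical registry:

# Hypothetical registry showing the difference when None is a key.
class GPTQHelper: ...
class NoneKeyHelper: ...

_GPTQ_HELPER_REGISTRY = {None: NoneKeyHelper}

backend = None
old_cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)  # NoneKeyHelper
new_cls = GPTQHelper if backend is None else _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
print(old_cls.__name__, new_cls.__name__)  # NoneKeyHelper GPTQHelper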
