Skip to content

Commit c3f93ef

Browse files
committed
latest tested on Qwen3-8B
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent d101aba commit c3f93ef

5 files changed

Lines changed: 438 additions & 63 deletions

File tree

examples/llm_ptq/example_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ def get_model(
583583
model_kwargs = config_kwargs.copy()
584584
# Don't set torch_dtype for VILA models as they handle it explicitly in their builder
585585
if "vila" not in ckpt_path.lower():
586-
model_kwargs.setdefault("dtype", "auto")
586+
model_kwargs.setdefault("torch_dtype", "auto")
587587

588588
if "vila" in ckpt_path.lower():
589589
hf_vila = AutoModel.from_pretrained(
@@ -666,7 +666,7 @@ def has_pack_quantized_config(config):
666666
model_kwargs2 = model_kwargs.copy()
667667
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
668668
model_kwargs2.pop("trust_remote_code", None)
669-
model_kwargs2["dtype"] = torch_dtype
669+
model_kwargs2["torch_dtype"] = torch_dtype
670670
model_kwargs2.pop("max_memory", None)
671671
model = from_config(hf_config, **model_kwargs2)
672672

examples/llm_ptq/hf_ptq.py

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
)
7272
from modelopt.torch.utils.dataset_utils import (
7373
create_forward_loop,
74+
get_calib_and_holdout_dataloaders,
7475
get_dataset_dataloader,
7576
get_max_batch_size,
7677
get_supported_datasets,
@@ -203,9 +204,10 @@ def make_calib_dataloader(
203204
tokenizer: PreTrainedTokenizerBase | None,
204205
device: torch.device,
205206
model_type: str | None,
206-
) -> tuple[DataLoader | _DeviceDataLoader, str | None]:
207+
) -> tuple[DataLoader | _DeviceDataLoader, str | None, Path | None]:
207208
calib_dataloader = None
208209
first_text_speech_dataset = None
210+
holdout_path = None
209211
if args.specdec_offline_dataset is not None:
210212
offline_data_path = Path(args.specdec_offline_dataset)
211213
dumped_files = sorted(str(p) for p in offline_data_path.glob("*.pt"))
@@ -283,15 +285,29 @@ def make_calib_dataloader(
283285
include_labels = (
284286
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
285287
)
286-
calib_dataloader = get_dataset_dataloader(
287-
dataset_name=args.dataset,
288-
tokenizer=tokenizer,
289-
batch_size=args.batch_size,
290-
num_samples=args.calib_size,
291-
device=device,
292-
include_labels=include_labels,
293-
)
294-
return calib_dataloader, first_text_speech_dataset
288+
289+
if args.holdout_size > 0:
290+
calib_dataloader, holdout_path = get_calib_and_holdout_dataloaders(
291+
dataset_name=args.dataset,
292+
tokenizer=tokenizer,
293+
batch_size=args.batch_size,
294+
calib_size=args.calib_size,
295+
holdout_size=args.holdout_size,
296+
max_sample_length=args.calib_seq,
297+
device=device,
298+
include_labels=include_labels,
299+
save_dir=args.calib_data_dir,
300+
)
301+
else:
302+
calib_dataloader = get_dataset_dataloader(
303+
dataset_name=args.dataset,
304+
tokenizer=tokenizer,
305+
batch_size=args.batch_size,
306+
num_samples=args.calib_size,
307+
device=device,
308+
include_labels=include_labels,
309+
)
310+
return calib_dataloader, first_text_speech_dataset, holdout_path
295311

296312

297313
def auto_quantize(
@@ -419,10 +435,15 @@ def load_model(args: argparse.Namespace):
419435
attn_implementation=args.attn_implementation,
420436
)
421437
else:
422-
assert args.qformat in QUANT_CFG_CHOICES, (
423-
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
424-
)
425-
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
438+
if args.qformat in QUANT_CFG_CHOICES:
439+
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
440+
elif hasattr(mtq, args.qformat):
441+
quant_cfg = getattr(mtq, args.qformat)
442+
else:
443+
raise AssertionError(
444+
f"Quantization format is not supported for low memory mode. "
445+
f"Supported formats: {QUANT_CFG_CHOICES.keys()}"
446+
)
426447
if args.kv_cache_qformat != "none":
427448
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
428449
quant_cfg,
@@ -1028,7 +1049,7 @@ def quantize_main(
10281049

10291050
print(f"Use calib batch_size {args.batch_size}")
10301051

1031-
calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
1052+
calib_dataloader, first_text_speech_dataset, holdout_path = make_calib_dataloader(
10321053
args, language_model, processor, tokenizer, device, model_type
10331054
)
10341055

@@ -1066,10 +1087,14 @@ def quantize_main(
10661087
"Plain quantization supports only one quantization format."
10671088
)
10681089

1069-
assert args.qformat in QUANT_CFG_CHOICES, (
1070-
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
1071-
)
1072-
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
1090+
if args.qformat in QUANT_CFG_CHOICES:
1091+
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
1092+
elif hasattr(mtq, args.qformat):
1093+
quant_cfg = getattr(mtq, args.qformat)
1094+
else:
1095+
raise AssertionError(
1096+
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
1097+
)
10731098

10741099
quant_cfg = build_quant_cfg(
10751100
args.qformat,
@@ -1104,7 +1129,7 @@ def quantize_main(
11041129
quant_cfg = copy.deepcopy(quant_cfg)
11051130
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])
11061131

1107-
if args.qformat in QUANT_CFG_CHOICES:
1132+
if args.qformat in QUANT_CFG_CHOICES or hasattr(mtq, args.qformat):
11081133
mono_quantize(
11091134
args,
11101135
quant_cfg,
@@ -1180,6 +1205,26 @@ def parse_args() -> argparse.Namespace:
11801205
type=str,
11811206
default="512",
11821207
)
1208+
parser.add_argument(
1209+
"--holdout_size",
1210+
help=(
1211+
"Number of holdout samples to save as a .pt file for evaluation. "
1212+
"Holdout samples are drawn from the same dataset immediately after "
1213+
"the calibration samples so there is no overlap. 0 disables holdout."
1214+
),
1215+
type=int,
1216+
default=0,
1217+
)
1218+
parser.add_argument(
1219+
"--calib_data_dir",
1220+
help=(
1221+
"Directory to save/load calib.pt and holdout.pt. "
1222+
"If both files exist, data is reloaded from disk instead of re-downloading. "
1223+
"Defaults to --export_path if not specified."
1224+
),
1225+
type=str,
1226+
default=None,
1227+
)
11831228
parser.add_argument(
11841229
"--calib_seq",
11851230
help="Maximum sequence length for calibration.",

modelopt/torch/quantization/model_calib.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,21 @@ def sequential_calibrate(
15891589

15901590
def _layer_forward_loop(m, _inputs=layer_inputs):
15911591
for args, kwargs_input in _inputs:
1592+
# Reset past_key_values to prevent the KV cache from
1593+
# accumulating across multiple forward replays (e.g.
1594+
# max_calibrate then Hessian collection in GPTQ).
1595+
# The layer doesn't need stale KV data — each replay
1596+
# should start with a fresh cache.
1597+
if (
1598+
"past_key_values" in kwargs_input
1599+
and kwargs_input["past_key_values"] is not None
1600+
):
1601+
kwargs_input = dict(kwargs_input)
1602+
cache = kwargs_input["past_key_values"]
1603+
if hasattr(cache, "reset"):
1604+
cache.reset()
1605+
else:
1606+
kwargs_input["past_key_values"] = None
15921607
m(*args, **kwargs_input)
15931608

15941609
calib_func(layer, _layer_forward_loop, **calib_kwargs)
@@ -1665,6 +1680,10 @@ def gptq(
16651680
print_rank_0("Updating weights using GPTQ algorithm...")
16661681
for handle in gptq_handles.values():
16671682
handle.update_weights(block_size, perc_damp)
1683+
1684+
# Disable weight quantizer after running GPTQ update since weights are already QDQ'ed
1685+
if hasattr(handle.module, "weight_quantizer"):
1686+
handle.module.weight_quantizer.disable()
16681687
handle.free()
16691688
del gptq_handles
16701689

0 commit comments

Comments
 (0)