|
71 | 71 | ) |
72 | 72 | from modelopt.torch.utils.dataset_utils import ( |
73 | 73 | create_forward_loop, |
| 74 | + get_calib_and_holdout_dataloaders, |
74 | 75 | get_dataset_dataloader, |
75 | 76 | get_max_batch_size, |
76 | 77 | get_supported_datasets, |
@@ -203,9 +204,10 @@ def make_calib_dataloader( |
203 | 204 | tokenizer: PreTrainedTokenizerBase | None, |
204 | 205 | device: torch.device, |
205 | 206 | model_type: str | None, |
206 | | -) -> tuple[DataLoader | _DeviceDataLoader, str | None]: |
| 207 | +) -> tuple[DataLoader | _DeviceDataLoader, str | None, Path | None]: |
207 | 208 | calib_dataloader = None |
208 | 209 | first_text_speech_dataset = None |
| 210 | + holdout_path = None |
209 | 211 | if args.specdec_offline_dataset is not None: |
210 | 212 | offline_data_path = Path(args.specdec_offline_dataset) |
211 | 213 | dumped_files = sorted(str(p) for p in offline_data_path.glob("*.pt")) |
@@ -283,15 +285,29 @@ def make_calib_dataloader( |
283 | 285 | include_labels = ( |
284 | 286 | args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" |
285 | 287 | ) |
286 | | - calib_dataloader = get_dataset_dataloader( |
287 | | - dataset_name=args.dataset, |
288 | | - tokenizer=tokenizer, |
289 | | - batch_size=args.batch_size, |
290 | | - num_samples=args.calib_size, |
291 | | - device=device, |
292 | | - include_labels=include_labels, |
293 | | - ) |
294 | | - return calib_dataloader, first_text_speech_dataset |
| 288 | + |
| 289 | + if args.holdout_size > 0: |
| 290 | + calib_dataloader, holdout_path = get_calib_and_holdout_dataloaders( |
| 291 | + dataset_name=args.dataset, |
| 292 | + tokenizer=tokenizer, |
| 293 | + batch_size=args.batch_size, |
| 294 | + calib_size=args.calib_size, |
| 295 | + holdout_size=args.holdout_size, |
| 296 | + max_sample_length=args.calib_seq, |
| 297 | + device=device, |
| 298 | + include_labels=include_labels, |
| 299 | + save_dir=args.calib_data_dir, |
| 300 | + ) |
| 301 | + else: |
| 302 | + calib_dataloader = get_dataset_dataloader( |
| 303 | + dataset_name=args.dataset, |
| 304 | + tokenizer=tokenizer, |
| 305 | + batch_size=args.batch_size, |
| 306 | + num_samples=args.calib_size, |
| 307 | + device=device, |
| 308 | + include_labels=include_labels, |
| 309 | + ) |
| 310 | + return calib_dataloader, first_text_speech_dataset, holdout_path |
295 | 311 |
|
296 | 312 |
|
297 | 313 | def auto_quantize( |
@@ -419,10 +435,15 @@ def load_model(args: argparse.Namespace): |
419 | 435 | attn_implementation=args.attn_implementation, |
420 | 436 | ) |
421 | 437 | else: |
422 | | - assert args.qformat in QUANT_CFG_CHOICES, ( |
423 | | - f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}" |
424 | | - ) |
425 | | - quant_cfg = QUANT_CFG_CHOICES[args.qformat] |
| 438 | + if args.qformat in QUANT_CFG_CHOICES: |
| 439 | + quant_cfg = QUANT_CFG_CHOICES[args.qformat] |
| 440 | + elif hasattr(mtq, args.qformat): |
| 441 | + quant_cfg = getattr(mtq, args.qformat) |
| 442 | + else: |
| 443 | + raise AssertionError( |
| 444 | + f"Quantization format is not supported for low memory mode. " |
| 445 | + f"Supported formats: {QUANT_CFG_CHOICES.keys()}" |
| 446 | + ) |
426 | 447 | if args.kv_cache_qformat != "none": |
427 | 448 | quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( |
428 | 449 | quant_cfg, |
@@ -1028,7 +1049,7 @@ def quantize_main( |
1028 | 1049 |
|
1029 | 1050 | print(f"Use calib batch_size {args.batch_size}") |
1030 | 1051 |
|
1031 | | - calib_dataloader, first_text_speech_dataset = make_calib_dataloader( |
| 1052 | + calib_dataloader, first_text_speech_dataset, holdout_path = make_calib_dataloader( |
1032 | 1053 | args, language_model, processor, tokenizer, device, model_type |
1033 | 1054 | ) |
1034 | 1055 |
|
@@ -1066,10 +1087,14 @@ def quantize_main( |
1066 | 1087 | "Plain quantization supports only one quantization format." |
1067 | 1088 | ) |
1068 | 1089 |
|
1069 | | - assert args.qformat in QUANT_CFG_CHOICES, ( |
1070 | | - f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}" |
1071 | | - ) |
1072 | | - quant_cfg = QUANT_CFG_CHOICES[args.qformat] |
| 1090 | + if args.qformat in QUANT_CFG_CHOICES: |
| 1091 | + quant_cfg = QUANT_CFG_CHOICES[args.qformat] |
| 1092 | + elif hasattr(mtq, args.qformat): |
| 1093 | + quant_cfg = getattr(mtq, args.qformat) |
| 1094 | + else: |
| 1095 | + raise AssertionError( |
| 1096 | + f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}" |
| 1097 | + ) |
1073 | 1098 |
|
1074 | 1099 | quant_cfg = build_quant_cfg( |
1075 | 1100 | args.qformat, |
@@ -1104,7 +1129,7 @@ def quantize_main( |
1104 | 1129 | quant_cfg = copy.deepcopy(quant_cfg) |
1105 | 1130 | _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) |
1106 | 1131 |
|
1107 | | - if args.qformat in QUANT_CFG_CHOICES: |
| 1132 | + if args.qformat in QUANT_CFG_CHOICES or hasattr(mtq, args.qformat): |
1108 | 1133 | mono_quantize( |
1109 | 1134 | args, |
1110 | 1135 | quant_cfg, |
@@ -1180,6 +1205,26 @@ def parse_args() -> argparse.Namespace: |
1180 | 1205 | type=str, |
1181 | 1206 | default="512", |
1182 | 1207 | ) |
| 1208 | + parser.add_argument( |
| 1209 | + "--holdout_size", |
| 1210 | + help=( |
| 1211 | + "Number of holdout samples to save as a .pt file for evaluation. " |
| 1212 | + "Holdout samples are drawn from the same dataset immediately after " |
| 1213 | + "the calibration samples so there is no overlap. 0 disables holdout." |
| 1214 | + ), |
| 1215 | + type=int, |
| 1216 | + default=0, |
| 1217 | + ) |
| 1218 | + parser.add_argument( |
| 1219 | + "--calib_data_dir", |
| 1220 | + help=( |
| 1221 | + "Directory to save/load calib.pt and holdout.pt. " |
| 1222 | + "If both files exist, data is reloaded from disk instead of re-downloading. " |
| 1223 | + "Defaults to --export_path if not specified." |
| 1224 | + ), |
| 1225 | + type=str, |
| 1226 | + default=None, |
| 1227 | + ) |
1183 | 1228 | parser.add_argument( |
1184 | 1229 | "--calib_seq", |
1185 | 1230 | help="Maximum sequence length for calibration.", |
|
0 commit comments