Skip to content

Commit dfedafa

Browse files
committed
Optimize calibration for text data
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent f0519b1 commit dfedafa

4 files changed

Lines changed: 68 additions & 63 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,12 @@
6363
create_forward_loop,
6464
get_dataset_dataloader,
6565
get_max_batch_size,
66-
get_qwen3omni_text_dataloader,
6766
get_supported_datasets,
6867
)
6968
from modelopt.torch.utils.image_processor import (
7069
BaseImageProcessor,
7170
MllamaImageProcessor,
7271
Qwen3OmniImageProcessor,
73-
Qwen3OmniTextProcessor,
7472
)
7573
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
7674
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
@@ -196,50 +194,47 @@ def make_calib_dataloader(
196194
num_samples=args.calib_size[0],
197195
)
198196
elif model_type == "qwen3omni":
199-
assert processor is not None, "The processor must be set for qwen3omni model."
200197
dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail"
201198
# Check if using video dataset (e.g., finevideo)
202-
if dataset_name in get_supported_video_datasets():
203-
video_processor = Qwen3OmniVideoProcessor(
204-
processor.tokenizer if hasattr(processor, "tokenizer") else processor,
205-
device=device,
206-
dtype=language_model.dtype,
207-
use_audio_in_video=True,
208-
)
209-
calib_dataloader = get_video_dataset_dataloader(
210-
dataset_name=dataset_name,
211-
processor=video_processor,
212-
batch_size=args.batch_size,
213-
num_samples=args.calib_size[0],
214-
)
215-
elif dataset_name in get_supported_vlm_datasets():
216-
assert isinstance(processor, Qwen3OmniImageProcessor), (
217-
"The Qwen3OmniImageProcessor must be set."
218-
)
219-
# Set the dtype for proper tensor conversion in collate_function
220-
processor.dtype = language_model.dtype
221-
calib_dataloader = get_vlm_dataset_dataloader(
222-
dataset_name=dataset_name,
223-
processor=processor,
224-
batch_size=args.batch_size,
225-
num_samples=args.calib_size[0],
226-
)
199+
if processor is not None:
200+
if dataset_name in get_supported_video_datasets():
201+
video_processor = Qwen3OmniVideoProcessor(
202+
processor.tokenizer if hasattr(processor, "tokenizer") else processor,
203+
device=device,
204+
dtype=language_model.dtype,
205+
use_audio_in_video=True,
206+
)
207+
calib_dataloader = get_video_dataset_dataloader(
208+
dataset_name=dataset_name,
209+
processor=video_processor,
210+
batch_size=args.batch_size,
211+
num_samples=args.calib_size[0],
212+
)
213+
elif dataset_name in get_supported_vlm_datasets():
214+
assert isinstance(processor, Qwen3OmniImageProcessor), (
215+
"The Qwen3OmniImageProcessor must be set."
216+
)
217+
# Set the dtype for proper tensor conversion in collate_function
218+
processor.dtype = language_model.dtype
219+
calib_dataloader = get_vlm_dataset_dataloader(
220+
dataset_name=dataset_name,
221+
processor=processor,
222+
batch_size=args.batch_size,
223+
num_samples=args.calib_size[0],
224+
)
227225
else:
228-
# Text-only datasets (e.g., cnn_dailymail)
229-
# Use Qwen3OmniTextProcessor to apply proper conversation template
230-
# See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking
231-
text_processor = Qwen3OmniTextProcessor(
232-
processor=processor.tokenizer, # Pass the underlying HF processor
233-
device=device,
234-
dtype=language_model.dtype,
226+
# Labels are only needed for gradient-based auto_quantize
227+
include_labels = (
228+
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
235229
)
236-
calib_dataloader = get_qwen3omni_text_dataloader(
237-
dataset_name=dataset_name,
238-
processor=text_processor,
230+
calib_dataloader = get_dataset_dataloader(
231+
dataset_name=args.dataset,
232+
tokenizer=tokenizer,
239233
batch_size=args.batch_size,
240-
num_samples=args.calib_size[0],
234+
num_samples=args.calib_size,
235+
device=device,
236+
include_labels=include_labels,
241237
)
242-
print(f"Selected dataset for calibration: {dataset_name}")
243238
elif model_type == "whisper":
244239
assert processor is not None and isinstance(processor, WhisperProcessor), (
245240
"The AutoProcessor must be set."
@@ -410,9 +405,6 @@ def load_model(args: argparse.Namespace):
410405
calibration_only = True
411406

412407
model_type = get_model_type(full_model)
413-
if model_type == "qwen3omni":
414-
print("Disabling talker for Qwen3Omni model")
415-
full_model.disable_talker()
416408

417409
device = full_model.device
418410
if hasattr(full_model, "model"):
@@ -432,6 +424,14 @@ def load_model(args: argparse.Namespace):
432424
trust_remote_code=args.trust_remote_code,
433425
attn_implementation=args.attn_implementation,
434426
)
427+
if model_type == "qwen3omni":
428+
print("Disabling talker for Qwen3Omni model")
429+
full_model.disable_talker()
430+
language_model = full_model.thinker.model
431+
tokenizer = processor.tokenizer.tokenizer
432+
processor = None
433+
default_padding_side = tokenizer.padding_side
434+
default_pad_token = tokenizer.pad_token
435435
elif model_type == "whisper":
436436
processor = get_processor(
437437
args.pyt_ckpt_path,
@@ -567,13 +567,16 @@ def mono_quantize(
567567
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
568568

569569
# For Qwen3Omni models, disable quantization of conv layers
570+
generation_kwargs = {}
570571
if model_type == "qwen3omni":
571572
print(
572573
"Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model"
573574
)
574575
quant_cfg["quant_cfg"]["*conv*"] = {"enable": False}
575576
quant_cfg["quant_cfg"]["*audio_tower*"] = {"enable": False}
576577
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
578+
generation_kwargs["return_audio"] = False
579+
generation_kwargs["thinker_max_new_tokens"] = 1
577580

578581
if not model_is_already_quantized or calibration_only:
579582
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
@@ -592,7 +595,9 @@ def mono_quantize(
592595
if args.calib_with_images and is_nemotron_vl_model:
593596
calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader)
594597
else:
595-
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
598+
calibrate_loop = create_forward_loop(
599+
dataloader=calib_dataloader, generation_kwargs=generation_kwargs
600+
)
596601

597602
if calibration_only:
598603
language_model = mtq.calibrate(
@@ -756,7 +761,7 @@ def pre_quantize(
756761
elif model_type == "qwen3omni":
757762
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
758763
# Pass full batch with all multimodal inputs
759-
result = full_model.generate(**calib_batch, max_new_tokens=100)
764+
result = full_model.generate(**calib_batch, return_audio=False, thinker_max_new_tokens=100)
760765
if isinstance(result, tuple):
761766
text_ids, _ = result
762767
generated_ids_before_ptq = (
@@ -817,7 +822,7 @@ def post_quantize(
817822
elif model_type == "qwen3omni":
818823
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
819824
# Pass full batch with all multimodal inputs
820-
result = full_model.generate(**calib_batch, max_new_tokens=100)
825+
result = full_model.generate(**calib_batch, return_audio=False, thinker_max_new_tokens=100)
821826
if isinstance(result, tuple):
822827
text_ids, _ = result
823828
generated_ids_after_ptq = (

modelopt/torch/export/unified_export_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def llm_dummy_forward():
319319
if getattr(model.config, "is_encoder_decoder", False):
320320
# For encoder-decoder models, we need to pass both the encoder and decoder input ids
321321
model(fake_input, decoder_input_ids=decoder_fake_input)
322-
elif is_vl_model and "nemotron" in model_type:
322+
elif (is_vl_model and "nemotron" in model_type) or model_type.startswith("qwen3omni"):
323323
# For Nemotron VL models, try to run optimization on just the language model part
324324
language_model_lineage = get_language_model_from_vl(model)
325325

@@ -333,7 +333,7 @@ def llm_dummy_forward():
333333
language_model(fake_input)
334334
else:
335335
raise ValueError(
336-
f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
336+
f"Cannot extract language_model from VL model (type: {model_type}). "
337337
"This is required for requantization/resmoothing optimization. "
338338
"Please ensure the model architecture is supported or file an issue."
339339
)

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -798,14 +798,9 @@ def unpack_weight(self):
798798

799799
try:
800800
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
801-
Qwen3OmniMoeTalkerTextSparseMoeBlock,
802801
Qwen3OmniMoeThinkerTextSparseMoeBlock,
803802
)
804803

805-
if Qwen3OmniMoeTalkerTextSparseMoeBlock not in QuantModuleRegistry:
806-
QuantModuleRegistry.register(
807-
{Qwen3OmniMoeTalkerTextSparseMoeBlock: "hf.Qwen3OmniMoeTalkerTextSparseMoeBlock"}
808-
)(_QuantSparseMoe)
809804
if Qwen3OmniMoeThinkerTextSparseMoeBlock not in QuantModuleRegistry:
810805
QuantModuleRegistry.register(
811806
{Qwen3OmniMoeThinkerTextSparseMoeBlock: "hf.Qwen3OmniMoeThinkerTextSparseMoeBlock"}

modelopt/torch/utils/dataset_utils.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,13 @@ def _get_free_gpu_mem():
450450
return 512
451451

452452

453-
def _process_batch(batch_data, infer_method, max_working_batch_size=None):
453+
def _process_batch(batch_data, infer_method, generation_kwargs={}, max_working_batch_size=None):
454454
"""Process a batch of data through the model's inference method.
455455
456456
Args:
457457
batch_data: Dictionary containing the batch data
458458
infer_method: Model's inference method (either forward or generate)
459+
generation_kwargs: Keyword arguments to pass to the model.generate() method.
459460
max_working_batch_size: Maximum batch size known to work without OOM
460461
461462
Returns:
@@ -493,7 +494,7 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None):
493494

494495
# Try processing with current batch size
495496
try:
496-
infer_method(**batch_data)
497+
infer_method(**batch_data, **generation_kwargs)
497498
return (
498499
batch_size
499500
if max_working_batch_size is None
@@ -524,24 +525,27 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None):
524525
return max_working_batch_size
525526

526527

527-
def _forward_loop(model: torch.nn.Module, dataloader: DataLoader) -> None:
528+
def _forward_loop(
529+
model: torch.nn.Module, dataloader: DataLoader, generation_kwargs: dict = {}
530+
) -> None:
528531
"""Runs forward passes through the model using data from the dataloader.
529532
530533
Args:
531534
model: The PyTorch model to run inference on
532535
dataloader: DataLoader containing the batched input data
536+
generation_kwargs: Keyword arguments to pass to the model.generate() method.
533537
"""
534538
with torch.no_grad():
535-
use_generate = _should_use_generate(model)
539+
# use_generate = _should_use_generate(model)
540+
use_generate = model_type_is_enc_dec(model)
536541
infer_method = model.generate if use_generate else model.forward
537542
max_working_batch_size = None # Initialize max working batch size as None
538543

539544
for _, data in enumerate(tqdm(dataloader)):
540-
# For generate(), add max_new_tokens to prevent indefinite generation during calibration
541-
if use_generate:
542-
data["max_new_tokens"] = 1
543545
# Process batch and update max working batch size
544-
max_working_batch_size = _process_batch(data, infer_method, max_working_batch_size)
546+
max_working_batch_size = _process_batch(
547+
data, infer_method, generation_kwargs, max_working_batch_size
548+
)
545549

546550

547551
def create_forward_loop(
@@ -554,6 +558,7 @@ def create_forward_loop(
554558
device: str | None = None,
555559
include_labels: bool = False,
556560
dataloader: DataLoader | None = None,
561+
generation_kwargs: dict = {},
557562
) -> Callable:
558563
"""Creates and returns a forward loop function configured for a specific model, dataset, and tokenizer.
559564
@@ -572,7 +577,7 @@ def create_forward_loop(
572577
device: Target device for the returned dataloader.
573578
include_labels: Whether to include labels in the dataloader.
574579
dataloader: If provided, use the provided dataloader instead.
575-
580+
generation_kwargs: Keyword arguments to pass to the model.generate() method.
576581
Example usage for quantization:
577582
578583
.. code-block:: python
@@ -611,7 +616,7 @@ def create_forward_loop(
611616
include_labels=include_labels,
612617
)
613618

614-
return lambda model: _forward_loop(model, dataloader)
619+
return lambda model: _forward_loop(model, dataloader, generation_kwargs)
615620

616621

617622
def model_type_is_enc_dec(model):

0 commit comments

Comments (0)