Skip to content

Commit f0519b1

Browse files
committed
Add support for Qwen3Omni30B thinking model
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 944dd1a commit f0519b1

8 files changed

Lines changed: 494 additions & 66 deletions

File tree

examples/llm_ptq/example_utils.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@
4242
snapshot_download = None
4343

4444
import modelopt.torch.quantization as mtq
45-
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
45+
from modelopt.torch.utils.image_processor import (
46+
BaseImageProcessor,
47+
MllamaImageProcessor,
48+
Qwen3OmniImageProcessor,
49+
)
4650

4751
SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
4852

@@ -284,7 +288,7 @@ def get_processor(
284288
if attn_implementation is not None:
285289
model_kwargs["attn_implementation"] = attn_implementation
286290

287-
if model_type == "whisper":
291+
if model_type in ("whisper", "mllama", "qwen3omni"):
288292
processor = AutoProcessor.from_pretrained(
289293
ckpt_path,
290294
padding_side="left",
@@ -296,20 +300,11 @@ def get_processor(
296300
f"Pad token for {ckpt_path} cannot be set!"
297301
)
298302

303+
if model_type == "mllama":
304+
return MllamaImageProcessor(processor, device)
305+
elif model_type == "qwen3omni":
306+
return Qwen3OmniImageProcessor(processor, device)
299307
return processor
300-
elif model_type == "mllama":
301-
processor = AutoProcessor.from_pretrained(
302-
ckpt_path,
303-
padding_side="left",
304-
**model_kwargs,
305-
)
306-
if processor.tokenizer.pad_token is None:
307-
processor.tokenizer.pad_token = processor.tokenizer.eos_token
308-
assert processor.tokenizer.pad_token is not None, (
309-
f"Pad token for {ckpt_path} cannot be set!"
310-
)
311-
312-
return MllamaImageProcessor(processor, device)
313308

314309
return None
315310

examples/llm_ptq/hf_ptq.py

Lines changed: 139 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
# limitations under the License.
1515

1616
import argparse
17+
import io
1718
import random
19+
import sys
1820
import time
1921
import warnings
2022
from typing import Any
@@ -61,12 +63,26 @@
6163
create_forward_loop,
6264
get_dataset_dataloader,
6365
get_max_batch_size,
66+
get_qwen3omni_text_dataloader,
6467
get_supported_datasets,
6568
)
66-
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
69+
from modelopt.torch.utils.image_processor import (
70+
BaseImageProcessor,
71+
MllamaImageProcessor,
72+
Qwen3OmniImageProcessor,
73+
Qwen3OmniTextProcessor,
74+
)
6775
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
6876
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
69-
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
77+
from modelopt.torch.utils.video_dataset_utils import (
78+
Qwen3OmniVideoProcessor,
79+
get_supported_video_datasets,
80+
get_video_dataset_dataloader,
81+
)
82+
from modelopt.torch.utils.vlm_dataset_utils import (
83+
get_supported_vlm_datasets,
84+
get_vlm_dataset_dataloader,
85+
)
7086

7187
RAND_SEED = 1234
7288

@@ -179,6 +195,51 @@ def make_calib_dataloader(
179195
batch_size=args.batch_size,
180196
num_samples=args.calib_size[0],
181197
)
198+
elif model_type == "qwen3omni":
199+
assert processor is not None, "The processor must be set for qwen3omni model."
200+
dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail"
201+
# Check if using video dataset (e.g., finevideo)
202+
if dataset_name in get_supported_video_datasets():
203+
video_processor = Qwen3OmniVideoProcessor(
204+
processor.tokenizer if hasattr(processor, "tokenizer") else processor,
205+
device=device,
206+
dtype=language_model.dtype,
207+
use_audio_in_video=True,
208+
)
209+
calib_dataloader = get_video_dataset_dataloader(
210+
dataset_name=dataset_name,
211+
processor=video_processor,
212+
batch_size=args.batch_size,
213+
num_samples=args.calib_size[0],
214+
)
215+
elif dataset_name in get_supported_vlm_datasets():
216+
assert isinstance(processor, Qwen3OmniImageProcessor), (
217+
"The Qwen3OmniImageProcessor must be set."
218+
)
219+
# Set the dtype for proper tensor conversion in collate_function
220+
processor.dtype = language_model.dtype
221+
calib_dataloader = get_vlm_dataset_dataloader(
222+
dataset_name=dataset_name,
223+
processor=processor,
224+
batch_size=args.batch_size,
225+
num_samples=args.calib_size[0],
226+
)
227+
else:
228+
# Text-only datasets (e.g., cnn_dailymail)
229+
# Use Qwen3OmniTextProcessor to apply proper conversation template
230+
# See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking
231+
text_processor = Qwen3OmniTextProcessor(
232+
processor=processor.tokenizer, # Pass the underlying HF processor
233+
device=device,
234+
dtype=language_model.dtype,
235+
)
236+
calib_dataloader = get_qwen3omni_text_dataloader(
237+
dataset_name=dataset_name,
238+
processor=text_processor,
239+
batch_size=args.batch_size,
240+
num_samples=args.calib_size[0],
241+
)
242+
print(f"Selected dataset for calibration: {dataset_name}")
182243
elif model_type == "whisper":
183244
assert processor is not None and isinstance(processor, WhisperProcessor), (
184245
"The AutoProcessor must be set."
@@ -349,6 +410,9 @@ def load_model(args: argparse.Namespace):
349410
calibration_only = True
350411

351412
model_type = get_model_type(full_model)
413+
if model_type == "qwen3omni":
414+
print("Disabling talker for Qwen3Omni model")
415+
full_model.disable_talker()
352416

353417
device = full_model.device
354418
if hasattr(full_model, "model"):
@@ -360,7 +424,7 @@ def load_model(args: argparse.Namespace):
360424
default_pad_token = None
361425

362426
is_nemotron_vl_model = is_nemotron_vl(full_model)
363-
if model_type == "mllama":
427+
if model_type in ["mllama", "qwen3omni"]:
364428
processor = get_processor(
365429
args.pyt_ckpt_path,
366430
model_type,
@@ -502,6 +566,15 @@ def mono_quantize(
502566
quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
503567
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
504568

569+
# For Qwen3Omni models, disable quantization of conv layers
570+
if model_type == "qwen3omni":
571+
print(
572+
"Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model"
573+
)
574+
quant_cfg["quant_cfg"]["*conv*"] = {"enable": False}
575+
quant_cfg["quant_cfg"]["*audio_tower*"] = {"enable": False}
576+
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
577+
505578
if not model_is_already_quantized or calibration_only:
506579
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
507580
print("Applying nvfp4 quantization (MoE only) for gpt-oss")
@@ -662,9 +735,10 @@ def pre_quantize(
662735
663736
"""
664737
# Only run single sample for preview
665-
preview_input_ids = next(iter(calib_dataloader))[
666-
"input_features" if model_type == "whisper" else "input_ids"
667-
][0:1]
738+
calib_batch = next(iter(calib_dataloader))
739+
preview_input_ids = calib_batch["input_features" if model_type == "whisper" else "input_ids"][
740+
0:1
741+
]
668742

669743
# Generate preview before quantization
670744
if model_type == "deepseek":
@@ -679,13 +753,24 @@ def pre_quantize(
679753
"before quantization",
680754
allow_fallback=True,
681755
)
756+
elif model_type == "qwen3omni":
757+
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
758+
# Pass full batch with all multimodal inputs
759+
result = full_model.generate(**calib_batch, max_new_tokens=100)
760+
if isinstance(result, tuple):
761+
text_ids, _ = result
762+
generated_ids_before_ptq = (
763+
text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
764+
)
765+
else:
766+
generated_ids_before_ptq = result
682767
else:
683768
# Standard generation for non-Nemotron VL models
684769
generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
685770
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
686771
print("Applying nvfp4 quantization (MoE only) for gpt-oss")
687772

688-
return preview_input_ids, generated_ids_before_ptq
773+
return preview_input_ids, generated_ids_before_ptq, calib_batch
689774

690775

691776
def post_quantize(
@@ -698,6 +783,7 @@ def post_quantize(
698783
generated_ids_before_ptq,
699784
is_nemotron_vl_model,
700785
first_text_speech_dataset,
786+
calib_batch: dict | None = None,
701787
):
702788
"""
703789
Processing after the quantization.
@@ -708,13 +794,37 @@ def post_quantize(
708794
"""
709795

710796
if args.verbose:
711-
mtq.print_quant_summary(full_model)
797+
if args.quant_summary_path:
798+
# Capture the summary output to a file
799+
old_stdout = sys.stdout
800+
sys.stdout = buffer = io.StringIO()
801+
try:
802+
mtq.print_quant_summary(full_model)
803+
finally:
804+
sys.stdout = old_stdout
805+
summary = buffer.getvalue()
806+
with open(args.quant_summary_path, "w") as f:
807+
f.write(summary)
808+
print(f"Quantization summary saved to {args.quant_summary_path}")
809+
else:
810+
mtq.print_quant_summary(full_model)
712811

713812
# Run some samples
714813
torch.cuda.empty_cache()
715814
generated_ids_after_ptq = None
716815
if generated_ids_before_ptq is None:
717816
pass
817+
elif model_type == "qwen3omni":
818+
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
819+
# Pass full batch with all multimodal inputs
820+
result = full_model.generate(**calib_batch, max_new_tokens=100)
821+
if isinstance(result, tuple):
822+
text_ids, _ = result
823+
generated_ids_after_ptq = (
824+
text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
825+
)
826+
else:
827+
generated_ids_after_ptq = result
718828
elif model_type != "llama4" and not is_nemotron_vl_model:
719829
# Our fake quantizer may not be fully compatible with torch.compile.
720830
generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
@@ -733,12 +843,13 @@ def post_quantize(
733843
)
734844

735845
def input_decode(input_ids):
736-
if processor is not None and isinstance(processor, MllamaImageProcessor):
737-
return processor.tokenizer.batch_decode(input_ids)
846+
# BaseImageProcessor covers MllamaImageProcessor and Qwen3OmniImageProcessor
847+
if processor is not None and isinstance(processor, BaseImageProcessor):
848+
return processor.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
738849
elif processor is not None and isinstance(processor, WhisperProcessor):
739850
return first_text_speech_dataset
740851
elif tokenizer is not None:
741-
return tokenizer.batch_decode(input_ids)
852+
return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
742853
else:
743854
raise ValueError("The processor or tokenizer must be set")
744855

@@ -750,6 +861,12 @@ def output_decode(generated_ids, input_shape):
750861
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
751862
elif processor is not None and isinstance(processor, MllamaImageProcessor):
752863
return processor.tokenizer.batch_decode(generated_ids[:, input_shape:])
864+
elif processor is not None and isinstance(processor, Qwen3OmniImageProcessor):
865+
return processor.tokenizer.batch_decode(
866+
generated_ids[:, input_shape:],
867+
skip_special_tokens=True,
868+
clean_up_tokenization_spaces=False,
869+
)
753870
elif tokenizer is not None:
754871
return tokenizer.batch_decode(generated_ids[:, input_shape:])
755872
else:
@@ -831,7 +948,7 @@ def quantize_main(
831948
# Detect if this is a Nemotron VL model using architecture-based detection
832949
is_nemotron_vl_model = is_nemotron_vl(full_model)
833950

834-
preview_input_ids, generated_ids_before_ptq = pre_quantize(
951+
preview_input_ids, generated_ids_before_ptq, calib_batch = pre_quantize(
835952
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
836953
)
837954

@@ -903,6 +1020,7 @@ def quantize_main(
9031020
generated_ids_before_ptq,
9041021
is_nemotron_vl_model,
9051022
first_text_speech_dataset,
1023+
calib_batch,
9061024
)
9071025
export_quantized(
9081026
args,
@@ -1083,6 +1201,15 @@ def parse_args() -> argparse.Namespace:
10831201
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
10841202
),
10851203
)
1204+
parser.add_argument(
1205+
"--quant_summary_path",
1206+
type=str,
1207+
default=None,
1208+
help=(
1209+
"Path to save the quantization summary. If not specified, summary is printed to stdout. "
1210+
"Requires --verbose to be enabled (default: True)."
1211+
),
1212+
)
10861213

10871214
return parser.parse_args()
10881215

0 commit comments

Comments (0)