diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index aad29fc97c..03e0adbd67 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -42,7 +42,11 @@ snapshot_download = None import modelopt.torch.quantization as mtq -from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor +from modelopt.torch.utils.image_processor import ( + BaseImageProcessor, + MllamaImageProcessor, + Qwen3OmniImageProcessor, +) SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] @@ -240,6 +244,27 @@ def build_quant_cfg( quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False} quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False} + if model_type == "qwen3omni": + if qformat == "qwen3_nvfp4_qkv_disabled": + for proj in ["q_proj", "k_proj", "v_proj"]: + quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { + "enable": False + } + elif qformat == "qwen3_nvfp4_qkvo_disabled": + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: + quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = { + "enable": False + } + + elif qformat == "qwen3_nvfp4_first_and_last_n_disabled": + # Disable both first N and last N layers + total_layers = 48 + n_layers_to_disable = 4 + for i in range(n_layers_to_disable): + quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} + for i in range(total_layers - n_layers_to_disable, total_layers): + quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False} + return quant_cfg @@ -310,6 +335,19 @@ def get_processor( ) return MllamaImageProcessor(processor, device) + elif model_type == "qwen3omni": + processor = AutoProcessor.from_pretrained( + ckpt_path, + padding_side="left", + **model_kwargs, + ) + if processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + assert processor.tokenizer.pad_token is not None, ( + f"Pad token for {ckpt_path} cannot be set!" 
+ ) + + return Qwen3OmniImageProcessor(processor, device) return None diff --git a/examples/llm_ptq/generate_video_dataset.py b/examples/llm_ptq/generate_video_dataset.py new file mode 100644 index 0000000000..2f8d6bbf80 --- /dev/null +++ b/examples/llm_ptq/generate_video_dataset.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Script to pre-generate processed video dataset for Qwen3-Omni quantization.""" + +import argparse +import os + +import torch +from transformers import AutoProcessor + +from modelopt.torch.utils.video_dataset_utils import ( + Qwen3OmniVideoProcessor, + get_video_dataset_dataloader, +) + + +def main(): + parser = argparse.ArgumentParser(description="Generate processed video dataset cache") + parser.add_argument( + "--model-name", + type=str, + default="Qwen/Qwen3-Omni-30B-A3B-Thinking", + help="Model name or path for loading the processor", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="finevideo", + help="Name of the video dataset to process", + ) + parser.add_argument( + "--num-samples", + type=int, + default=512, + help="Number of samples to process", + ) + parser.add_argument( + "--cache-dir", + type=str, + required=True, + help="Directory to save the processed dataset cache", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + help="Data type for processing", + ) + parser.add_argument( + "--no-audio", + action="store_true", + help="Disable audio extraction from videos", + ) + args = parser.parse_args() + + use_audio = not args.no_audio + + # Set dtype + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtype = dtype_map[args.dtype] + + print(f"Loading processor from {args.model_name}...") + hf_processor = AutoProcessor.from_pretrained(args.model_name, trust_remote_code=True) + + print(f"Creating Qwen3OmniVideoProcessor (use_audio={use_audio}, dtype={args.dtype})...") + processor = Qwen3OmniVideoProcessor( + tokenizer=hf_processor, + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=dtype, + use_audio_in_video=use_audio, + ) + + print(f"Processing {args.num_samples} samples from {args.dataset_name}...") + print(f"Cache directory: {args.cache_dir}") + + # This will process and save to cache + _ = 
get_video_dataset_dataloader( + dataset_name=args.dataset_name, + processor=processor, + batch_size=1, + num_samples=args.num_samples, + cache_dir=args.cache_dir, + ) + + # Cleanup temp files + processor.cleanup() + + cache_path = os.path.join(args.cache_dir, f"{args.dataset_name}_n{args.num_samples}_processed") + print(f"\nDone! Processed dataset saved to: {cache_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index e32d0dae84..4f3a8af28a 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,7 +14,9 @@ # limitations under the License. import argparse +import io import random +import sys import time import warnings from typing import Any @@ -61,12 +63,26 @@ create_forward_loop, get_dataset_dataloader, get_max_batch_size, + get_qwen3omni_text_dataloader, get_supported_datasets, ) -from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor +from modelopt.torch.utils.image_processor import ( + BaseImageProcessor, + MllamaImageProcessor, + Qwen3OmniImageProcessor, + Qwen3OmniTextProcessor, +) from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader -from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader +from modelopt.torch.utils.video_dataset_utils import ( + Qwen3OmniVideoProcessor, + get_supported_video_datasets, + get_video_dataset_dataloader, +) +from modelopt.torch.utils.vlm_dataset_utils import ( + get_supported_vlm_datasets, + get_vlm_dataset_dataloader, +) RAND_SEED = 1234 @@ -86,6 +102,9 @@ "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, + "qwen3_nvfp4_qkv_disabled": mtq.NVFP4_DEFAULT_CFG, + "qwen3_nvfp4_qkvo_disabled": mtq.NVFP4_DEFAULT_CFG, + "qwen3_nvfp4_first_and_last_n_disabled": mtq.NVFP4_DEFAULT_CFG, } KV_QUANT_CFG_CHOICES = { @@ 
-179,6 +198,53 @@ def make_calib_dataloader( batch_size=args.batch_size, num_samples=args.calib_size[0], ) + elif model_type == "qwen3omni": + assert processor is not None, "The processor must be set for qwen3omni model." + dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail" + # Check if using video dataset (e.g., finevideo) + if dataset_name in get_supported_video_datasets(): + video_processor = Qwen3OmniVideoProcessor( + processor.tokenizer if hasattr(processor, "tokenizer") else processor, + device=device, + dtype=language_model.dtype, + use_audio_in_video=True, + ) + calib_dataloader = get_video_dataset_dataloader( + dataset_name=dataset_name, + processor=video_processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + ) + elif dataset_name in get_supported_vlm_datasets(): + assert isinstance(processor, Qwen3OmniImageProcessor), ( + "The Qwen3OmniImageProcessor must be set." + ) + # Set the dtype for proper tensor conversion in collate_function + processor.dtype = language_model.dtype + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name=dataset_name, + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + ) + else: + # Text-only datasets (e.g., cnn_dailymail) + # Use Qwen3OmniTextProcessor to apply proper conversation template + # See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking + text_processor = Qwen3OmniTextProcessor( + processor=processor.tokenizer, # Pass the underlying HF processor + device=device, + dtype=language_model.dtype, + ) + calib_dataloader = get_qwen3omni_text_dataloader( + dataset_name=dataset_name, + processor=text_processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + max_sample_length=args.calib_seq, + device=device, + ) + print(f"Selected dataset for calibration: {dataset_name}") elif model_type == "whisper": assert processor is not None and isinstance(processor, WhisperProcessor), ( "The AutoProcessor must be set." 
@@ -324,11 +390,14 @@ def load_model(args: argparse.Namespace): use_seq_device_map=args.use_seq_device_map, attn_implementation=args.attn_implementation, ) + + quant_cfg = QUANT_CFG_CHOICES[args.qformat] else: assert args.qformat in QUANT_CFG_CHOICES, ( f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}" ) quant_cfg = QUANT_CFG_CHOICES[args.qformat] + if args.kv_cache_qformat != "none": quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( quant_cfg, @@ -349,10 +418,14 @@ def load_model(args: argparse.Namespace): calibration_only = True model_type = get_model_type(full_model) + if model_type == "qwen3omni": + print("Disabling talker for Qwen3Omni model") + full_model.disable_talker() device = full_model.device if hasattr(full_model, "model"): device = full_model.model.device + processor = None tokenizer = None language_model = full_model @@ -360,7 +433,8 @@ def load_model(args: argparse.Namespace): default_pad_token = None is_nemotron_vl_model = is_nemotron_vl(full_model) - if model_type == "mllama": + + if model_type in ["mllama", "qwen3omni"]: processor = get_processor( args.pyt_ckpt_path, model_type, @@ -502,6 +576,15 @@ def mono_quantize( quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + # For Qwen3Omni models, disable quantization of conv layers + if model_type == "qwen3omni": + print( + "Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model" + ) + quant_cfg["quant_cfg"]["*conv*"] = {"enable": False} + quant_cfg["quant_cfg"]["*audio_tower*"] = {"enable": False} + quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + if not model_is_already_quantized or calibration_only: if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") @@ -534,7 +617,6 @@ def mono_quantize( if language_model_lineage is not None: print("Updating full_model with 
quantized language_model...") language_model_lineage[-2].language_model = language_model - else: warnings.warn("Skipping quantization: model is already quantized.") @@ -628,6 +710,7 @@ def export_quantized( export_hf_checkpoint( full_model, export_dir=export_path, + save_modelopt_state=model_type == "qwen3omni", ) # Copy custom model files (Python files and JSON configs) if trust_remote_code is used @@ -662,9 +745,10 @@ def pre_quantize( """ # Only run single sample for preview - preview_input_ids = next(iter(calib_dataloader))[ - "input_features" if model_type == "whisper" else "input_ids" - ][0:1] + calib_batch = next(iter(calib_dataloader)) + preview_input_ids = calib_batch["input_features" if model_type == "whisper" else "input_ids"][ + 0:1 + ] # Generate preview before quantization if model_type == "deepseek": @@ -679,13 +763,24 @@ def pre_quantize( "before quantization", allow_fallback=True, ) + elif model_type == "qwen3omni": + # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences + # Pass full batch with all multimodal inputs + result = full_model.generate(**calib_batch, max_new_tokens=100) + if isinstance(result, tuple): + text_ids, _ = result + generated_ids_before_ptq = ( + text_ids.sequences if hasattr(text_ids, "sequences") else text_ids + ) + else: + generated_ids_before_ptq = result else: # Standard generation for non-Nemotron VL models generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") - return preview_input_ids, generated_ids_before_ptq + return preview_input_ids, generated_ids_before_ptq, calib_batch def post_quantize( @@ -698,6 +793,7 @@ def post_quantize( generated_ids_before_ptq, is_nemotron_vl_model, first_text_speech_dataset, + calib_batch: dict | None = None, ): """ Processing after the quantization. 
@@ -708,13 +804,37 @@ def post_quantize( """ if args.verbose: - mtq.print_quant_summary(full_model) + if args.quant_summary_path: + # Capture the summary output to a file + old_stdout = sys.stdout + sys.stdout = buffer = io.StringIO() + try: + mtq.print_quant_summary(full_model) + finally: + sys.stdout = old_stdout + summary = buffer.getvalue() + with open(args.quant_summary_path, "w") as f: + f.write(summary) + print(f"Quantization summary saved to {args.quant_summary_path}") + else: + mtq.print_quant_summary(full_model) # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None if generated_ids_before_ptq is None: pass + elif model_type == "qwen3omni": + # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences + # Pass full batch with all multimodal inputs + result = full_model.generate(**calib_batch, max_new_tokens=100) + if isinstance(result, tuple): + text_ids, _ = result + generated_ids_after_ptq = ( + text_ids.sequences if hasattr(text_ids, "sequences") else text_ids + ) + else: + generated_ids_after_ptq = result elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. 
generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) @@ -733,7 +853,8 @@ def post_quantize( ) def input_decode(input_ids): - if processor is not None and isinstance(processor, MllamaImageProcessor): + # BaseImageProcessor covers MllamaImageProcessor and Qwen3OmniImageProcessor + if processor is not None and isinstance(processor, BaseImageProcessor): return processor.tokenizer.batch_decode(input_ids) elif processor is not None and isinstance(processor, WhisperProcessor): return first_text_speech_dataset @@ -750,6 +871,12 @@ def output_decode(generated_ids, input_shape): return tokenizer.batch_decode(generated_ids, skip_special_tokens=True) elif processor is not None and isinstance(processor, MllamaImageProcessor): return processor.tokenizer.batch_decode(generated_ids[:, input_shape:]) + elif processor is not None and isinstance(processor, Qwen3OmniImageProcessor): + return processor.tokenizer.batch_decode( + generated_ids[:, input_shape:], + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) elif tokenizer is not None: return tokenizer.batch_decode(generated_ids[:, input_shape:]) else: @@ -831,7 +958,7 @@ def quantize_main( # Detect if this is a Nemotron VL model using architecture-based detection is_nemotron_vl_model = is_nemotron_vl(full_model) - preview_input_ids, generated_ids_before_ptq = pre_quantize( + preview_input_ids, generated_ids_before_ptq, calib_batch = pre_quantize( args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) @@ -903,6 +1030,7 @@ def quantize_main( generated_ids_before_ptq, is_nemotron_vl_model, first_text_speech_dataset, + calib_batch, ) export_quantized( args, @@ -1083,6 +1211,15 @@ def parse_args() -> argparse.Namespace: "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." ), ) + parser.add_argument( + "--quant_summary_path", + type=str, + default=None, + help=( + "Path to save the quantization summary. 
If not specified, summary is printed to stdout. " + "Requires --verbose to be enabled (default: True)." + ), + ) return parser.parse_args() diff --git a/examples/llm_ptq/run_quantized_qwen3omni.py b/examples/llm_ptq/run_quantized_qwen3omni.py new file mode 100644 index 0000000000..b11f8d37cc --- /dev/null +++ b/examples/llm_ptq/run_quantized_qwen3omni.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Script to load and run a quantized Qwen3Omni model from export_hf_checkpoint or mto.save().""" + +import argparse +import time + +import torch +from qwen_omni_utils import process_mm_info +from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor + +import modelopt.torch.opt as mto + +# Enable HuggingFace checkpointing for modelopt quantized models +mto.enable_huggingface_checkpointing() + + +def main(args): + if args.pt_checkpoint_path: + # Load base model first, then restore quantization state from mto.save() checkpoint + print("Loading base model from Qwen/Qwen3-Omni-30B-A3B-Thinking...") + model = Qwen3OmniMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-Omni-30B-A3B-Thinking", + torch_dtype="auto", + device_map="auto", + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + print(f"Restoring quantization state from {args.pt_checkpoint_path}...") + model = mto.restore(model, args.pt_checkpoint_path) + else: + # Load from HF checkpoint exported with export_hf_checkpoint() + print(f"Loading quantized model from {args.hf_checkpoint_path}...") + model = Qwen3OmniMoeForConditionalGeneration.from_pretrained( + args.hf_checkpoint_path, + torch_dtype="auto", + device_map="auto", + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + + model.disable_talker() + + print("Loading processor...") + processor = Qwen3OmniMoeProcessor.from_pretrained( + "Qwen/Qwen3-Omni-30B-A3B-Thinking", + trust_remote_code=True, + ) + + # Build conversation with user prompt + prompt = args.prompt or "What is the capital of France?" 
+ conversation = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + conversations = [conversation] + + # Set whether to use audio in video + use_audio_in_video = True + + # Preparation for inference + texts = processor.apply_chat_template( + conversations, add_generation_prompt=True, tokenize=False, enable_thinking=False + ) + print(f"Texts: {texts}") + audios, images, videos = process_mm_info(conversations, use_audio_in_video=use_audio_in_video) + + inputs = processor( + text=texts, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + inputs = inputs.to(model.device).to(model.dtype) + + print(f"\nPrompt: {prompt}") + print("Generating...") + + start_time = time.time() + with torch.no_grad(): + text_ids, _ = model.generate( + **inputs, + thinker_return_dict_in_generate=True, + use_audio_in_video=use_audio_in_video, + max_new_tokens=args.max_new_tokens, + return_audio=False, + ) + end_time = time.time() + print(f"Time taken for generation: {end_time - start_time:.2f} seconds") + + # Decode the generated tokens + generated_text = processor.batch_decode( + text_ids.sequences[:, inputs["input_ids"].shape[1] :], + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + print(f"\nGenerated: {generated_text[0]}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run quantized Qwen3Omni model") + parser.add_argument( + "--hf_checkpoint_path", + type=str, + default=None, + help="Path to the export_hf_checkpoint() quantized checkpoint directory", + ) + parser.add_argument( + "--pt_checkpoint_path", + type=str, + default=None, + help="Path to the mto.save() checkpoint file", + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Text prompt for generation", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=512, + help="Maximum new tokens to generate", + ) + + args = parser.parse_args() + + 
# Validate arguments + if not args.hf_checkpoint_path and not args.pt_checkpoint_path: + parser.error("Either --hf_checkpoint_path or --pt_checkpoint_path must be provided") + + main(args) diff --git a/examples/llm_ptq/run_qwen_vllm.py b/examples/llm_ptq/run_qwen_vllm.py new file mode 100644 index 0000000000..f5f775d4d9 --- /dev/null +++ b/examples/llm_ptq/run_qwen_vllm.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3-Omni-30B-A3B text inference with vLLM. 
+ +Usage: + python run_qwen_vllm.py + python run_qwen_vllm.py --model /path/to/model --tp 4 +""" + +from __future__ import annotations + +import argparse +import os +import shutil + +# import vllm.model_executor.parameter as vllm_param +from huggingface_hub import snapshot_download +from transformers import Qwen3OmniMoeProcessor +from vllm import LLM, SamplingParams + +MODEL_ID = "Qwen/Qwen3-Omni-30B-A3B-Thinking" + + +# # Debug patch to identify which weights cause shape mismatch +# def _patch_weight_loader_for_debug(): +# """Monkey-patch vLLM weight loader to print debug info on shape mismatch.""" +# original_load_column_parallel = vllm_param.ModelWeightParameter.load_column_parallel_weight + +# def debug_load_column_parallel(self, loaded_weight): +# print(f"Loading param: {getattr(self, 'name', getattr(self, '_name', repr(self)))}") +# print(f" Parameter shape (expected): {self.data.shape}") +# print(f" Loaded weight shape (got): {loaded_weight.shape}") + +# return original_load_column_parallel(self, loaded_weight) + +# vllm_param.ModelWeightParameter.load_column_parallel_weight = debug_load_column_parallel +# print("DEBUG: Patched vLLM weight loader to print shape mismatch info") + + +# _patch_weight_loader_for_debug() + +# Files needed for tokenizer/processor that vLLM loads from model path +TOKENIZER_FILES = [ + "vocab.json", + "merges.txt", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "preprocessor_config.json", + "chat_template.json", +] + + +def ensure_tokenizer_files(model_path: str, source_model_id: str) -> None: + """Copy tokenizer files from HF model to local quantized model dir if missing.""" + if not os.path.isdir(model_path): + return # Not a local path, nothing to do + + # Check if tokenizer files are missing + missing_files = [f for f in TOKENIZER_FILES if not os.path.exists(os.path.join(model_path, f))] + if not missing_files: + return + + print(f"Copying missing tokenizer files from {source_model_id}...") + #
Download only tokenizer files from HF + cache_dir = snapshot_download( + source_model_id, + allow_patterns=TOKENIZER_FILES, + ) + + for fname in TOKENIZER_FILES: + src = os.path.join(cache_dir, fname) + dst = os.path.join(model_path, fname) + if os.path.exists(src) and not os.path.exists(dst): + shutil.copy2(src, dst) + print(f" Copied {fname}") + + +def main(): + parser = argparse.ArgumentParser(description="Run Qwen3-Omni text inference with vLLM") + parser.add_argument("--model", default=MODEL_ID, help="Model ID or path") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--max-model-len", type=int, default=32768, help="Max model length") + + args = parser.parse_args() + + # Load processor for chat template + processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_ID) + + # Text-only conversations + conversations = [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "What are the key features of Qwen3-Omni?"}], + } + ], + ] + + # Apply chat template with thinking disabled + texts = processor.apply_chat_template( + conversations, + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, + ) + + # Process multimodal info (returns empty for text-only) + # audios, images, videos = process_mm_info(conversations, use_audio_in_video=False) + + # Ensure tokenizer files exist in local model dir (vLLM loads processor from model path) + ensure_tokenizer_files(args.model, MODEL_ID) + + print(f"Loading model: {args.model}") + llm = LLM( + model=args.model, + tokenizer=MODEL_ID, # Always use original tokenizer from HF + tensor_parallel_size=args.tp, + max_model_len=args.max_model_len, + trust_remote_code=True, + quantization="modelopt_fp4", + ) + + sampling_params = SamplingParams( + temperature=0.7, + top_p=0.9, + max_tokens=512, + ) + + print("Running inference...") + outputs = llm.generate(texts, sampling_params) + + for output in outputs: + generated_text = output.outputs[0].text + print("-" 
* 80) + print(f"Generated: {generated_text}") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 5a24429ad7..4e08f3dccb 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -31,6 +31,7 @@ "ChatGLM": "chatglm", "Qwen3Moe": "qwen3moe", "Qwen3Next": "qwen3next", + "Qwen3OmniMoeForConditionalGeneration": "qwen3omni", "QWen": "qwen", "RecurrentGemma": "recurrentgemma", "Gemma3": "gemma3", diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 011af533dd..6136ae39df 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -300,29 +300,43 @@ def llm_dummy_forward(): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) - if getattr(model.config, "is_encoder_decoder", False): - # For encoder-decoder models, we need to pass both the encoder and decoder input ids - model(fake_input, decoder_input_ids=decoder_fake_input) - elif is_vl_model and "nemotron" in model_type: - # For Nemotron VL models, try to run optimization on just the language model part - language_model_lineage = get_language_model_from_vl(model) - - if language_model_lineage is not None: - # Run optimization on just the language model with the same input format as regular LLMs - # Use the same fake_input tensor that regular LLMs use - language_model = language_model_lineage[-1] - print( - f"Running optimization on language model with fake_input shape: {fake_input.shape}" - ) - language_model(fake_input) + with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + if getattr(model.config, "is_encoder_decoder", False): + # For encoder-decoder models, we need to pass both the encoder and decoder input ids + model(fake_input, decoder_input_ids=decoder_fake_input) + elif is_vl_model and "nemotron" in model_type: + # For Nemotron 
VL models, try to run optimization on just the language model part + language_model_lineage = get_language_model_from_vl(model) + + if language_model_lineage is not None: + # Run optimization on just the language model with the same input format as regular LLMs + # Use the same fake_input tensor that regular LLMs use + language_model = language_model_lineage[-1] + print( + f"Running optimization on language model with fake_input shape: {fake_input.shape}" + ) + language_model(fake_input) + else: + raise ValueError( + f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " + "This is required for requantization/resmoothing optimization. " + "Please ensure the model architecture is supported or file an issue." + ) + elif "qwen3omni" in model_type: + # For Qwen3Omni, run on the thinker (language model) component + # The model has structure: model.thinker.model.layers.* + if hasattr(model, "thinker"): + print( + f"Running optimization on Qwen3Omni thinker with fake_input shape: {fake_input.shape}" + ) + model.thinker(fake_input) + else: + raise ValueError( + f"Cannot extract thinker from Qwen3Omni model (type: {model_type}). " + "This is required for requantization/resmoothing optimization." + ) else: - raise ValueError( - f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " - "This is required for requantization/resmoothing optimization. " - "Please ensure the model architecture is supported or file an issue." 
- ) - else: - model(fake_input) + model(fake_input) input_to_linear, output_to_layernorm = _collect_shared_input_modules( model, llm_dummy_forward, collect_layernorms=True @@ -380,6 +394,19 @@ def _export_quantized_weight( weight_quantizer: TensorQuantizer | SequentialQuantizer = getattr( sub_module, quantizer_attrs.weight_quantizer ) + + # Skip export if weight quantizer is disabled or has no amax (not calibrated) + if not _is_enabled_quantizer(weight_quantizer): + return + + # Check if weight quantizer has calibrated amax + def _has_amax(quantizer): + if isinstance(quantizer, SequentialQuantizer): + return any(hasattr(q, "_amax") and q._amax is not None for q in quantizer) + return hasattr(quantizer, "_amax") and quantizer._amax is not None + + if not _has_amax(weight_quantizer): + return input_quantizer: TensorQuantizer | SequentialQuantizer | None = getattr( sub_module, quantizer_attrs.input_quantizer, None ) @@ -543,6 +570,7 @@ def _process_quantized_modules( model: nn.Module, dtype: torch.dtype, is_modelopt_qlora: bool = False, + pack_weights: bool = True, ) -> None: """Process all quantized modules in model, export weights in-place. @@ -555,6 +583,7 @@ def _process_quantized_modules( dtype: The data type for weight conversion. is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model. If True, modules with base_layer attribute are skipped. + pack_weights: Whether to pack quantized weights. 
""" fsdp_module_to_reshard = None @@ -577,8 +606,9 @@ def _process_quantized_modules( sub_module.unpack_weight() if get_quantization_format(sub_module) != QUANTIZATION_NONE: if is_quantlinear(sub_module): - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) + if pack_weights: + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) elif ( "Llama4TextExperts" in type(sub_module).__name__ or "GptOssExperts" in type(sub_module).__name__ @@ -595,13 +625,18 @@ def _process_quantized_modules( quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], ) # Export the quantized weights - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - for weight_name in ["gate_up_proj", "down_proj"]: - _export_quantized_weight(sub_module, dtype, weight_name) + if pack_weights: + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + for weight_name in ["gate_up_proj", "down_proj"]: + _export_quantized_weight(sub_module, dtype, weight_name) def _export_transformers_checkpoint( - model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs + model: nn.Module, + dtype: torch.dtype | None = None, + is_modelopt_qlora: bool = False, + pack_weights: bool = True, + **kwargs, ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -611,6 +646,7 @@ def _export_transformers_checkpoint( model: the full torch model to export. The actual quantized model may be a submodule. dtype: the weights data type to export the unquantized layers or the default model data type if None. accelerator: the accelerator instance in case of distributed export setup. + pack_weights: whether to pack quantized weights (False keeps original shapes for HF reload). 
Returns: post_state_dict: Dict containing quantized weights @@ -695,7 +731,7 @@ def _export_transformers_checkpoint( quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora) # Process all quantized modules and export weights - _process_quantized_modules(model, dtype, is_modelopt_qlora) + _process_quantized_modules(model, dtype, is_modelopt_qlora, pack_weights) if accelerator is not None: # Gather state_dict from all ranks @@ -964,7 +1000,12 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) + # Packed weights are only for TRT-LLM consumption + # Set this to true if you want to save the weights in the original precision + pack_weights = True + post_state_dict, hf_quant_config = _export_transformers_checkpoint( + model, dtype, pack_weights=pack_weights + ) if hf_quant_config is not None: # Save hf_quant_config.json for backward compatibility @@ -977,6 +1018,16 @@ def export_hf_checkpoint( if getattr(model, "hf_quantizer", None) is not None: model.hf_quantizer = None + # Fix generation_config conflicts before saving + # Some models have temperature/top_p/top_k set but do_sample=False which causes validation errors + if hasattr(model, "generation_config") and model.generation_config is not None: + gen_config = model.generation_config + if not getattr(gen_config, "do_sample", True): + # Remove sampling-related params when do_sample is False + for attr in ["temperature", "top_p", "top_k"]: + if hasattr(gen_config, attr): + setattr(gen_config, attr, None) + # Save model model.save_pretrained( export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index a29d7c7549..54b98052f1 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -796,6 +796,24 @@ def unpack_weight(self): 
def get_qwen3omni_text_dataloader(
    dataset_name: str | list[str] = "cnn_dailymail",
    processor=None,
    batch_size: int = 1,
    num_samples: int | list[int] = 512,
    max_sample_length: int = 512,
    device: str | None = None,
) -> DataLoader:
    """Get a text-only dataloader for Qwen3-Omni with the conversation template applied.

    Every raw text sample is passed through ``processor.preprocess_function`` (which
    applies the Qwen3-Omni chat template and tokenizes), converted to plain Python
    lists, and clipped to ``max_sample_length`` tokens.

    See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking

    Args:
        dataset_name: Name of the dataset(s) to load.
        processor: Qwen3OmniTextProcessor instance wrapping the Qwen3OmniMoeProcessor.
            It must provide ``preprocess_function`` and ``collate_function``.
        batch_size: Batch size of the returned dataloader.
        num_samples: Number of samples to draw from each dataset; must have as many
            entries as ``dataset_name``.
        max_sample_length: Maximum number of tokens kept per sample. Truncation is
            applied to ``input_ids`` and ``attention_mask`` after templating, so the
            tail of an over-long prompt (including the generation prompt) is cut.
        device: Unused. Tensors are moved to the device owned by ``processor`` inside
            its ``collate_function``; the parameter is kept for signature parity with
            the other dataloader helpers.

    Returns:
        A DataLoader whose batches are produced by ``processor.collate_function``.
    """
    assert processor is not None, "Please provide a Qwen3OmniTextProcessor."

    if isinstance(num_samples, int):
        num_samples = [num_samples]

    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    assert len(dataset_name) == len(num_samples), (
        "dataset_name and num_samples must be the same length"
    )

    # Keys that hold one entry per token and therefore must be clipped together.
    token_keys = ("input_ids", "attention_mask")

    def _truncate(value):
        """Clip a flat or batched (nested) token list to ``max_sample_length``."""
        if value and isinstance(value[0], list):
            return [row[:max_sample_length] for row in value]
        return value[:max_sample_length]

    processed_samples = []
    for ds_name, num_sample in zip(dataset_name, num_samples):
        for text in _get_dataset_samples(ds_name, num_sample):
            # Apply the conversation template and tokenize.
            values = processor.preprocess_function(text)

            # Store plain lists so samples stay cheap to hold; the processor's
            # collate function converts them back to tensors.
            sample_dict = {}
            for key, val in values.items():
                if val is None:
                    continue
                if hasattr(val, "tolist"):
                    val = val.tolist()
                if key in token_keys and isinstance(val, list):
                    val = _truncate(val)
                sample_dict[key] = val
            processed_samples.append(sample_dict)

    # A list already satisfies the map-style dataset protocol (__getitem__/__len__),
    # so no wrapper Dataset class is needed.
    return DataLoader(
        processed_samples,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=processor.collate_function,
    )
= None else: - split_data[key] = batch_data[key][i:end_idx, ...] + split_data[key] = tensor_data[key][i:end_idx, ...] + # Add back scalar data (non-tensor params like max_new_tokens) + split_data.update(scalar_data) max_working_batch_size = _process_batch( split_data, infer_method, max_working_batch_size @@ -392,8 +481,11 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None): # Split the batch in half mid = (batch_size + 1) // 2 warn(f"CUDA out of memory with batch size {batch_size}, trying with batch size {mid}") - split_data_1 = {key: batch_data[key][:mid, ...] for key in batch_data} - split_data_2 = {key: batch_data[key][mid:, ...] for key in batch_data} + split_data_1 = {key: tensor_data[key][:mid, ...] for key in tensor_data} + split_data_2 = {key: tensor_data[key][mid:, ...] for key in tensor_data} + # Add back scalar data (non-tensor params like max_new_tokens) + split_data_1.update(scalar_data) + split_data_2.update(scalar_data) # Recursively process each half and track max working batch size max_working_batch_size = _process_batch(split_data_1, infer_method) @@ -411,11 +503,14 @@ def _forward_loop(model: torch.nn.Module, dataloader: DataLoader) -> None: dataloader: DataLoader containing the batched input data """ with torch.no_grad(): - is_enc_dec = model_type_is_enc_dec(model) - infer_method = model.generate if is_enc_dec else model.forward + use_generate = _should_use_generate(model) + infer_method = model.generate if use_generate else model.forward max_working_batch_size = None # Initialize max working batch size as None for _, data in enumerate(tqdm(dataloader)): + # For generate(), add max_new_tokens to prevent indefinite generation during calibration + if use_generate: + data["max_new_tokens"] = 1 # Process batch and update max working batch size max_working_batch_size = _process_batch(data, infer_method, max_working_batch_size) @@ -493,3 +588,15 @@ def create_forward_loop( def model_type_is_enc_dec(model): enc_dec_model_list = 
class Qwen3OmniTextProcessor(BaseImageProcessor):
    """Text-only processor for Qwen3-Omni that applies the conversation template.

    Raw text is wrapped in a single-turn user conversation, rendered with the chat
    template and tokenized. Use this for text-only calibration datasets.

    See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking
    """

    def __init__(self, processor, device="auto", dtype=None):
        """Constructor.

        Args:
            processor: The Qwen3OmniMoeProcessor (from AutoProcessor.from_pretrained).
            device: Device to move tensors to.
            dtype: dtype for float tensors (e.g., torch.bfloat16). Stored on the
                instance; the text-only path produces integer tensors only, so it is
                not applied anywhere in this class.
        """
        super().__init__(processor, device)
        self.dtype = dtype

    def preprocess_function(self, text: str) -> dict:
        """Preprocess a single text sample by applying the conversation template.

        Args:
            text: Raw text string from dataset.

        Returns:
            The processor output for the templated text (tokenized inputs).
        """
        # Build conversation in Qwen format (text-only).
        conversation = [{"role": "user", "content": [{"type": "text", "text": text}]}]
        formatted_text = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
        )

        # Tokenize with the processor (no multimodal inputs).
        return self.tokenizer(
            text=formatted_text,
            audio=None,
            images=None,
            videos=None,
            return_tensors="pt",
            padding=True,
        )

    def collate_function(self, batch):
        """Turn the first sample of ``batch`` into device tensors.

        Only ``batch[0]`` is used, i.e. the processor supports ``batch_size=1``.
        A warning is emitted when more samples are passed so that they are not
        dropped silently.

        Args:
            batch: List of sample dicts holding ``input_ids`` / ``attention_mask`` lists.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` tensors (when present).
        """
        if len(batch) > 1:
            import warnings

            warnings.warn(
                f"Qwen3OmniTextProcessor only supports batch_size=1; dropping "
                f"{len(batch) - 1} sample(s) from the current batch."
            )

        first = batch[0]
        result = {}
        for key in ("input_ids", "attention_mask"):
            if key in first and first[key] is not None:
                result[key] = torch.LongTensor(first[key]).to(self.device)

        return result
class Qwen3OmniImageProcessor(BaseImageProcessor):
    """Image processor for Qwen3-Omni multimodal model."""

    # Integer-valued entries (token ids, masks, grid descriptors, lengths).
    _LONG_KEYS = (
        "input_ids",
        "attention_mask",
        "image_grid_thw",
        "video_grid_thw",
        "audio_feature_lens",
        "feature_attention_mask",
    )
    # Float-valued entries (pixel data, audio features, temporal info).
    _FLOAT_KEYS = (
        "pixel_values",
        "pixel_values_videos",
        "video_second_per_grid",
        "audio_features",
        "input_features",
    )

    def __init__(self, tokenizer, device="auto", use_audio_in_video=False):
        """Constructor.

        Args:
            tokenizer: The Qwen3-Omni processor used for templating and encoding.
            device: Device to move tensors to.
            use_audio_in_video: Forwarded to ``process_mm_info`` and to the processor.

        Raises:
            ImportError: If ``qwen_omni_utils`` is not installed.
        """
        super().__init__(tokenizer, device)
        self.use_audio_in_video = use_audio_in_video
        # qwen_omni_utils is needed to extract the multimodal inputs.
        try:
            from qwen_omni_utils import process_mm_info
        except ImportError as err:
            raise ImportError(
                "qwen_omni_utils is required for Qwen3OmniImageProcessor. "
                "Please install it from https://github.com/QwenLM/Qwen3-Omni"
            ) from err
        self.process_mm_info = process_mm_info

    def preprocess_function(self, examples):
        """Preprocess one example into list-valued model inputs.

        Args:
            examples: Mapping that may contain ``image``, ``audio``, ``video`` and
                ``question`` entries.

        Returns:
            Dict with a fixed set of keys (``None`` when absent) whose tensor values
            were converted to lists; ``collate_function`` converts them back.
        """
        question = examples.get("question", "Describe this image.")

        # Build conversation in Qwen format: media entries first, then the question.
        content = [
            {"type": modality, modality: examples[modality]}
            for modality in ("image", "audio", "video")
            if examples.get(modality) is not None
        ]
        content.append({"type": "text", "text": question})

        conversation = [{"role": "user", "content": content}]
        text = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
        )

        # Extract multimodal info using qwen_omni_utils.
        audios, images, videos = self.process_mm_info(
            conversation, use_audio_in_video=self.use_audio_in_video
        )

        # Process inputs with the processor.
        values = self.tokenizer(
            text=text,
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=self.use_audio_in_video,
        )

        # Define all possible keys to ensure a consistent schema for Arrow
        # serialization. The audio/video names mirror the ones handled by
        # Qwen3OmniVideoProcessor; the older "audio_features"/"audio_feature_lens"
        # names are kept so existing consumers still find them.
        all_keys = [
            "input_ids",
            "attention_mask",
            "pixel_values",
            "image_grid_thw",
            "audio_features",
            "audio_feature_lens",
            "input_features",
            "feature_attention_mask",
            "pixel_values_videos",
            "video_grid_thw",
            "video_second_per_grid",
        ]

        # Convert tensors to lists for Arrow serialization compatibility.
        # Tensor conversion back happens in collate_function.
        result = dict.fromkeys(all_keys)  # Initialize all keys to None
        for key, val in values.items():
            if val is None:
                continue
            result[key] = val.tolist() if hasattr(val, "tolist") else val

        return result

    def collate_function(self, batch):
        """Turn the first sample of ``batch`` into device tensors.

        Only ``batch[0]`` is used, i.e. the processor supports ``batch_size=1``;
        a warning is emitted when further samples would be dropped.

        Args:
            batch: List of dicts as produced by ``preprocess_function``.

        Returns:
            Dict of tensors for every entry that is present, plus
            ``use_audio_in_video=True`` when the processor was configured with it.
        """
        if len(batch) > 1:
            import warnings

            warnings.warn(
                f"Qwen3OmniImageProcessor only supports batch_size=1; dropping "
                f"{len(batch) - 1} sample(s) from the current batch."
            )

        first = batch[0]
        result = {}

        for key in self._LONG_KEYS:
            if first.get(key) is not None:
                result[key] = torch.LongTensor(first[key]).to(self.device)

        for key in self._FLOAT_KEYS:
            if first.get(key) is not None:
                result[key] = torch.tensor(first[key]).to(self.device)

        # Same convention as Qwen3OmniVideoProcessor: hand the flag on to the model
        # call. Only added when enabled so the default output stays tensor-only.
        if self.use_audio_in_video:
            result["use_audio_in_video"] = True

        return result
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for getting samples and forward loop function for video datasets.""" + +import os +import tempfile +from typing import Any + +import torch +from torch.utils.data import DataLoader + +from .image_processor import BaseImageProcessor + +# Use dict to store the config for each dataset. +SUPPORTED_VIDEO_DATASET_CONFIG: dict[str, dict[str, Any]] = { + "finevideo": { + "config": {"path": "HuggingFaceFV/finevideo", "split": "train", "streaming": True} + }, +} + +__all__ = [ + "Qwen3OmniVideoProcessor", + "get_supported_video_datasets", + "get_video_dataset_dataloader", +] + + +def _get_video_dataset(dataset_name: str, num_samples: int): + """Load a portion of train dataset with the dataset name and a given size. + + Args: + dataset_name: Name of the dataset to load. + num_samples: Number of samples to load from the dataset. + + Returns: + A hugging face Dataset. + """ + if dataset_name in SUPPORTED_VIDEO_DATASET_CONFIG: + from datasets import Dataset, load_dataset + + config = SUPPORTED_VIDEO_DATASET_CONFIG[dataset_name]["config"] + is_streaming = config.get("streaming", False) + + dataset = load_dataset(**config) + + if is_streaming: + # For streaming datasets, use take() and convert to list then Dataset + samples = list(dataset.take(num_samples)) + return Dataset.from_list(samples) + else: + return dataset.select(range(num_samples)) + else: + raise NotImplementedError( + f"dataset {dataset_name} is not supported. Please use one of the following:" + f" {get_supported_video_datasets()}." 
+ ) + + +def get_supported_video_datasets() -> list[str]: + """Retrieves a list of video datasets supported. + + Returns: + A list of strings, where each string is the name of a supported dataset. + + Example usage: + + .. code-block:: python + + from modelopt.torch.utils import get_supported_video_datasets + + print("Supported video datasets:", get_supported_video_datasets()) + """ + return list(SUPPORTED_VIDEO_DATASET_CONFIG.keys()) + + +def get_video_dataset_dataloader( + dataset_name: str = "finevideo", + processor: "Qwen3OmniVideoProcessor" = None, + batch_size: int = 1, + num_samples: int = 512, + cache_dir: str | None = None, +) -> DataLoader: + """Get a dataloader with the dataset name and processor of the target model. + + Args: + dataset_name: Name of the dataset to load. + processor: Processor used for encoding video and text data. + batch_size: Batch size of the returned dataloader. + num_samples: Number of samples from the dataset. + cache_dir: Directory to cache the processed dataset. Defaults to a temp directory. + If the cache exists, it will be loaded instead of reprocessing. + + Returns: + An instance of dataloader. + """ + assert processor is not None, "Please provide a valid processor." + + # Default cache_dir to temp directory + if cache_dir is None: + cache_dir = os.path.join(tempfile.gettempdir(), "modelopt_video_dataset_cache") + + processed_dataset = None + + # Try to load from cache (use torch.save/load to avoid Arrow 32-bit offset overflow) + if cache_dir is not None: + cache_path = os.path.join(cache_dir, f"{dataset_name}_n{num_samples}_processed.pt") + if os.path.exists(cache_path): + try: + from datasets import Dataset + + processed_samples = torch.load(cache_path, weights_only=False) + processed_dataset = Dataset.from_list(processed_samples) + print(f"Loaded processed dataset from cache: {cache_path}") + except Exception as e: + print(f"Failed to load cache from {cache_path}: {e}. 
class Qwen3OmniVideoProcessor(BaseImageProcessor):
    """Video processor for Qwen3-Omni multimodal model with finevideo dataset support."""

    # Integer-valued entries (token ids, masks, grid descriptors).
    _LONG_KEYS = ("input_ids", "attention_mask", "video_grid_thw", "feature_attention_mask")

    def __init__(self, tokenizer, device="cuda", dtype=None, use_audio_in_video=True):
        """Constructor.

        Args:
            tokenizer: The Qwen3OmniMoeProcessor for tokenizing and processing inputs.
            device: Device to move tensors to.
            dtype: dtype for float tensors (e.g., torch.bfloat16). If None, uses default.
            use_audio_in_video: Whether to extract and use audio from video files.

        Raises:
            ImportError: If ``qwen_omni_utils`` is not installed.
        """
        import shutil
        import weakref

        super().__init__(tokenizer, device)
        self.dtype = dtype
        self.use_audio_in_video = use_audio_in_video

        # Import first so that a missing dependency does not leave a temp dir behind.
        try:
            from qwen_omni_utils import process_mm_info
        except ImportError as err:
            raise ImportError(
                "qwen_omni_utils is required for Qwen3OmniVideoProcessor. "
                "Please install it from https://github.com/QwenLM/Qwen3-Omni"
            ) from err
        self.process_mm_info = process_mm_info

        self._temp_dir = tempfile.mkdtemp(prefix="qwen3omni_video_")
        self._video_counter = 0
        # Remove the temp dir when the processor is garbage collected or the
        # interpreter exits, even if cleanup() is never called explicitly.
        self._finalizer = weakref.finalize(
            self, shutil.rmtree, self._temp_dir, ignore_errors=True
        )

    def _save_video_bytes_to_file(self, video_bytes: bytes) -> str:
        """Save video bytes to a temporary file and return the path.

        Args:
            video_bytes: Raw video bytes (e.g., from finevideo's 'mp4' field).

        Returns:
            Path to the temporary video file.
        """
        # Recreate the directory if cleanup() already removed it.
        os.makedirs(self._temp_dir, exist_ok=True)
        video_path = os.path.join(self._temp_dir, f"video_{self._video_counter}.mp4")
        self._video_counter += 1
        with open(video_path, "wb") as f:
            f.write(video_bytes)
        return video_path

    def preprocess_function(self, examples):
        """Preprocess function for Qwen3-Omni with video support.

        Handles both standard video paths and raw video bytes (finevideo format).

        Args:
            examples: Mapping with an optional ``json`` metadata dict and either raw
                bytes under ``mp4`` or a path/URL under ``video``.

        Returns:
            Dict with a fixed set of keys (``None`` when absent) whose tensor values
            were converted to lists; ``collate_function`` converts them back.
        """
        # Get question/prompt - finevideo has metadata in 'json' field.
        if "json" in examples and examples["json"] is not None:
            metadata = examples["json"]
            # Try to get a meaningful question from metadata.
            category = metadata.get("content_fine_category", "")
            question = (
                f"Describe what is happening in this video in detail. Category hint: {category}"
            )
        else:
            question = examples.get("question", "Describe this video in detail.")

        # Handle video - check for raw bytes (finevideo format) or path.
        owned_video_path = None  # temp file written by us, removed once consumed
        video_path = None
        if examples.get("mp4") is not None:
            # finevideo format: raw video bytes in 'mp4' field.
            owned_video_path = self._save_video_bytes_to_file(examples["mp4"])
            video_path = owned_video_path
        elif examples.get("video") is not None:
            # Standard format: video path or URL.
            video_path = examples["video"]

        # Build conversation in Qwen format.
        content = []
        if video_path is not None:
            content.append({"type": "video", "video": video_path})
        content.append({"type": "text", "text": question})
        conversation = [{"role": "user", "content": content}]

        try:
            text = self.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
            )

            # Extract multimodal info using qwen_omni_utils.
            audios, images, videos = self.process_mm_info(
                conversation, use_audio_in_video=self.use_audio_in_video
            )

            # Process inputs with the processor.
            values = self.tokenizer(
                text=text,
                audio=audios,
                images=images,
                videos=videos,
                return_tensors="pt",
                padding=True,
                use_audio_in_video=self.use_audio_in_video,
            )
        finally:
            # The decoded inputs are in memory now; drop the temp copy so that large
            # datasets do not pile up one video file per sample on disk.
            # NOTE(review): assumes process_mm_info decodes eagerly - confirm.
            if owned_video_path is not None:
                try:
                    os.remove(owned_video_path)
                except OSError:
                    pass  # best effort; the whole temp dir is removed on cleanup

        # Define all possible keys to ensure consistent schema for Arrow serialization.
        all_keys = [
            "input_ids",
            "attention_mask",
            "pixel_values_videos",
            "video_grid_thw",
            "video_second_per_grid",
            "feature_attention_mask",
            "input_features",
        ]

        # Convert tensors to lists for Arrow serialization compatibility.
        # Tensor conversion back happens in collate_function.
        result = dict.fromkeys(all_keys)  # Initialize all keys to None
        for key, val in values.items():
            if val is None:
                continue
            result[key] = val.tolist() if hasattr(val, "tolist") else val

        return result

    def collate_function(self, batch):
        """Turn the first sample of ``batch`` into device tensors.

        Only ``batch[0]`` is used, i.e. the processor supports ``batch_size=1``;
        a warning is emitted when further samples would be dropped.

        Args:
            batch: List of dicts as produced by ``preprocess_function``.

        Returns:
            Dict of tensors for every entry that is present, plus the
            ``use_audio_in_video`` flag.
        """
        if len(batch) > 1:
            import warnings

            warnings.warn(
                f"Qwen3OmniVideoProcessor only supports batch_size=1; dropping "
                f"{len(batch) - 1} sample(s) from the current batch."
            )

        first = batch[0]
        result = {}

        # Convert integer-valued lists to tensors and move to device.
        for key in self._LONG_KEYS:
            if first.get(key) is not None:
                result[key] = torch.LongTensor(first[key]).to(self.device)

        # Pixel values and audio features are cast to the configured dtype.
        for key in ("pixel_values_videos", "input_features"):
            if first.get(key) is not None:
                tensor = torch.tensor(first[key])
                if self.dtype is not None:
                    tensor = tensor.to(self.dtype)
                result[key] = tensor.to(self.device)

        # Video seconds per grid (temporal info for rope) keeps its default dtype.
        if first.get("video_second_per_grid") is not None:
            result["video_second_per_grid"] = torch.tensor(first["video_second_per_grid"]).to(
                self.device
            )

        # Pass use_audio_in_video flag to model.generate() for Qwen3Omni.
        result["use_audio_in_video"] = self.use_audio_in_video

        return result

    def cleanup(self):
        """Clean up temporary video files."""
        import shutil

        if os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir, ignore_errors=True)
import DataLoader -from .image_processor import MllamaImageProcessor +from .image_processor import BaseImageProcessor, MllamaImageProcessor from .nemotron_vlm_dataset_utils import NemotronTarPlusJsonlIterable, list_repo_files_cached # Use dict to store the config for each dataset. @@ -331,7 +331,7 @@ def get_supported_vlm_datasets() -> list[str]: def get_vlm_dataset_dataloader( dataset_name: str = "scienceqa", - processor: Any = None, + processor: BaseImageProcessor | Any = None, batch_size: int = 1, num_samples: int = 512, device: str | torch.device | None = None,