-
Notifications
You must be signed in to change notification settings - Fork 364
Add support for Qwen3-Omni-30B-A3B-Thinking #677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
7dca9d4
Add support for Qwen3-Omni-30B-A3B-Thinking
ajrasane f6ac2d3
Add the finevideo dataset for calibration
ajrasane 616cc1b
Add option to disable talker
ajrasane cf3bbb8
Add quantization configs for the model
ajrasane c700f91
Register Qwen3 thinker and talker sparse moe blocks in quant module
ajrasane 445ff6a
remove first_n and last_n configs
ajrasane de01666
Update quantization modes to stack on top of one another
ajrasane 7ef534a
Add a text processor for text datasets
ajrasane 7aa5aed
Disable Qwen3OmniMoe class registration
cjluo-nv fdad81a
Update logic to disable quantizers
ajrasane 8a4cfac
Add option to save the quantized checkpoint
ajrasane e8d9b0e
Add a script to load and run the qwen3omni quantized checkpoint
ajrasane e3337a0
Create a script to cache the processed dataset
ajrasane aa77565
Support export to hf format
ajrasane 4f92fbf
restore configs
ajrasane 3f12551
Added script to run with vllm
ajrasane a2ec8f3
Disable audio tower and visual encoder quantization
ajrasane 0b1d9ca
Add a flag to save the quant summary
ajrasane 690620f
Forward tokens to all experts
ajrasane File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| #!/usr/bin/env python3 | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Script to pre-generate processed video dataset for Qwen3-Omni quantization.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this generation script qwen3_omni specific?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I dont think we need to merge this in our codebase. Will document this separately. |
||
|
|
||
| import argparse | ||
| import os | ||
|
|
||
| import torch | ||
| from transformers import AutoProcessor | ||
|
|
||
| from modelopt.torch.utils.video_dataset_utils import ( | ||
| Qwen3OmniVideoProcessor, | ||
| get_video_dataset_dataloader, | ||
| ) | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser(description="Generate processed video dataset cache") | ||
| parser.add_argument( | ||
| "--model-name", | ||
| type=str, | ||
| default="Qwen/Qwen3-Omni-30B-A3B-Thinking", | ||
| help="Model name or path for loading the processor", | ||
| ) | ||
| parser.add_argument( | ||
| "--dataset-name", | ||
| type=str, | ||
| default="finevideo", | ||
| help="Name of the video dataset to process", | ||
| ) | ||
| parser.add_argument( | ||
| "--num-samples", | ||
| type=int, | ||
| default=512, | ||
| help="Number of samples to process", | ||
| ) | ||
| parser.add_argument( | ||
| "--cache-dir", | ||
| type=str, | ||
| required=True, | ||
| help="Directory to save the processed dataset cache", | ||
| ) | ||
| parser.add_argument( | ||
| "--dtype", | ||
| type=str, | ||
| default="bfloat16", | ||
| choices=["float16", "bfloat16", "float32"], | ||
| help="Data type for processing", | ||
| ) | ||
| parser.add_argument( | ||
| "--no-audio", | ||
| action="store_true", | ||
| help="Disable audio extraction from videos", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| use_audio = not args.no_audio | ||
|
|
||
| # Set dtype | ||
| dtype_map = { | ||
| "float16": torch.float16, | ||
| "bfloat16": torch.bfloat16, | ||
| "float32": torch.float32, | ||
| } | ||
| dtype = dtype_map[args.dtype] | ||
|
|
||
| print(f"Loading processor from {args.model_name}...") | ||
| hf_processor = AutoProcessor.from_pretrained(args.model_name, trust_remote_code=True) | ||
|
|
||
| print(f"Creating Qwen3OmniVideoProcessor (use_audio={use_audio}, dtype={args.dtype})...") | ||
| processor = Qwen3OmniVideoProcessor( | ||
| tokenizer=hf_processor, | ||
| device="cuda" if torch.cuda.is_available() else "cpu", | ||
| dtype=dtype, | ||
| use_audio_in_video=use_audio, | ||
| ) | ||
|
|
||
| print(f"Processing {args.num_samples} samples from {args.dataset_name}...") | ||
| print(f"Cache directory: {args.cache_dir}") | ||
|
|
||
| # This will process and save to cache | ||
| _ = get_video_dataset_dataloader( | ||
| dataset_name=args.dataset_name, | ||
| processor=processor, | ||
| batch_size=1, | ||
| num_samples=args.num_samples, | ||
| cache_dir=args.cache_dir, | ||
| ) | ||
|
|
||
| # Cleanup temp files | ||
| processor.cleanup() | ||
|
|
||
| cache_path = os.path.join(args.cache_dir, f"{args.dataset_name}_n{args.num_samples}_processed") | ||
| print(f"\nDone! Processed dataset saved to: {cache_path}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel this level of qformat is too detailed. Can you recommend one and use it for Qwen3 Omni?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The basic nvfp4 format works fine, we can use that for now. I will add these formats in a separate document for later reference.