Skip to content

Commit e8d9b0e

Browse files
committed
Add a script to load and run the qwen3omni quantized checkpoint
1 parent 8a4cfac commit e8d9b0e

1 file changed

Lines changed: 128 additions & 0 deletions

File tree

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
18+
19+
"""Script to load and run a quantized Qwen3Omni model from mto checkpoint."""
20+
21+
import argparse
22+
import time
23+
24+
import torch
25+
from qwen_omni_utils import process_mm_info
26+
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
27+
28+
import modelopt.torch.opt as mto
29+
30+
31+
def main(args):
    """Load a Qwen3Omni model, restore its quantized state, and run one text generation.

    Args:
        args: Parsed CLI namespace providing ``model_path``, ``checkpoint_path``,
            ``prompt``, and ``max_new_tokens``.
    """
    print(f"Loading base model from {args.model_path}...")
    model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype="auto",
        device_map="cuda",
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )

    print(f"Restoring quantized state from {args.checkpoint_path}...")
    model = mto.restore(model, args.checkpoint_path)

    # Text-only run: drop the talker (speech-synthesis) branch.
    model.disable_talker()

    print("Loading processor...")
    processor = Qwen3OmniMoeProcessor.from_pretrained(
        args.model_path,
        trust_remote_code=True,
    )

    # Build a single-turn, text-only conversation from the user prompt.
    prompt = args.prompt or "What is the capital of France?"
    conversation = [{"role": "user", "content": [{"type": "text", "text": f"{prompt}"}]}]
    conversations = [conversation]

    # Whether audio tracks embedded in videos should be used; a no-op for this
    # text-only conversation but kept consistent between processor and generate.
    use_audio_in_video = True

    # Preparation for inference: chat-template the text, then extract any
    # multimodal payloads (all None/empty here).
    texts = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversations, use_audio_in_video=use_audio_in_video)

    inputs = processor(
        text=texts,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=use_audio_in_video,
    )
    inputs = inputs.to(model.device).to(model.dtype)

    print(f"\nPrompt: {prompt}")
    print("Generating...")

    # perf_counter is monotonic — immune to wall-clock adjustments, unlike time.time().
    start_time = time.perf_counter()
    with torch.no_grad():
        text_ids, _ = model.generate(
            **inputs,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=use_audio_in_video,
            max_new_tokens=args.max_new_tokens,
            return_audio=False,
        )
    # CUDA kernels launch asynchronously; synchronize so the elapsed time
    # covers the whole generation, not just kernel enqueueing.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"Time taken for generation: {end_time - start_time:.2f} seconds")

    # Strip the prompt tokens and decode only the newly generated continuation.
    generated_text = processor.batch_decode(
        text_ids.sequences[:, inputs["input_ids"].shape[1] :],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    print(f"\nGenerated: {generated_text[0]}")
98+
99+
100+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run quantized Qwen3Omni model")
    parser.add_argument(
        "--model_path",
        type=str,
        default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
        help="Path to the base Qwen3Omni model (HF format)",
    )
    # Required rather than defaulted: the previous default pointed at a
    # machine-specific personal scratch directory that does not exist elsewhere.
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the mto.save() quantized checkpoint",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Text prompt for generation",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum new tokens to generate",
    )

    args = parser.parse_args()
    main(args)

0 commit comments

Comments
 (0)