
Commit 2725797

Support export to hf format
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 06a90cd commit 2725797

5 files changed

Lines changed: 131 additions & 77 deletions

File tree

examples/llm_ptq/hf_ptq.py
examples/llm_ptq/run_quantized_qwen3omni.py
modelopt/torch/export/unified_export_hf.py
modelopt/torch/utils/image_processor.py
modelopt/torch/utils/video_dataset_utils.py

examples/llm_ptq/hf_ptq.py

Lines changed: 32 additions & 26 deletions
@@ -300,28 +300,11 @@ def main(args):
             attn_implementation=args.attn_implementation,
         )
 
+        # Uncomment this to load the model from a .pt file
+        # model = mto.restore(model, "./qwen3_omni_30b_nvfp4/model.pt")
+        # print("Qwen3Omni model restored from checkpoint")
+
         quant_cfg = QUANT_CFG_CHOICES[args.qformat]
-        # Qwen3 specific quantizer disabling patterns (thinker.model.layers only)
-        if "qkv_disabled" in args.qformat:
-            # Disable q_proj, k_proj, v_proj quantizers
-            for proj in ["q_proj", "k_proj", "v_proj"]:
-                quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
-                    "enable": False
-                }
-        if "qkvo_disabled" in args.qformat:
-            # Disable q_proj, k_proj, v_proj, o_proj quantizers
-            for proj in ["o_proj"]:
-                quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
-                    "enable": False
-                }
-        if "first_and_last_n_disabled" in args.qformat:
-            # Disable both first N and last N layers
-            total_layers = 48
-            n_layers_to_disable = 4
-            for i in range(n_layers_to_disable):
-                quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
-            for i in range(total_layers - n_layers_to_disable, total_layers):
-                quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
             f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"

@@ -606,6 +589,28 @@ def main(args):
         quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
 
+    # Qwen3 specific quantizer disabling patterns (thinker.model.layers only)
+    if "qkv_disabled" in args.qformat:
+        # Disable q_proj, k_proj, v_proj quantizers
+        for proj in ["q_proj", "k_proj", "v_proj"]:
+            quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
+                "enable": False
+            }
+    if "qkvo_disabled" in args.qformat:
+        # Disable q_proj, k_proj, v_proj, o_proj quantizers
+        for proj in ["o_proj"]:
+            quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
+                "enable": False
+            }
+    if "first_and_last_n_disabled" in args.qformat:
+        # Disable both first N and last N layers
+        total_layers = 48
+        n_layers_to_disable = 4
+        for i in range(n_layers_to_disable):
+            quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
+        for i in range(total_layers - n_layers_to_disable, total_layers):
+            quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
+
     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
         calib_batch = next(iter(calib_dataloader))

@@ -745,11 +750,11 @@ def output_decode(generated_ids, input_shape):
         assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
         print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
 
-        if model_type == "qwen3omni":
-            print("Export of Qwen3Omni model is not supported yet. Saving .pt file instead.")
-            os.makedirs(os.path.dirname(args.export_path), exist_ok=True)
-            mto.save(model, args.export_path)
-            return
+        # Uncomment this to save the model as a .pt file
+        # if model_type == "qwen3omni":
+        #     print("Export of Qwen3Omni model is not supported yet. Saving .pt file instead.")
+        #     os.makedirs(os.path.dirname(args.export_path), exist_ok=True)
+        #     mto.save(model, f"{args.export_path}/model.pt")
 
     with torch.inference_mode():
         if model_type is None:

@@ -828,6 +833,7 @@ def output_decode(generated_ids, input_shape):
             export_hf_checkpoint(
                 full_model,
                 export_dir=export_path,
+                save_modelopt_state=model_type == "qwen3omni",
             )
 
             # Copy custom model files (Python files and JSON configs) if trust_remote_code is used
examples/llm_ptq/run_quantized_qwen3omni.py

Lines changed: 15 additions & 18 deletions
@@ -16,7 +16,7 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-"""Script to load and run a quantized Qwen3Omni model from mto checkpoint."""
+"""Script to load and run a quantized Qwen3Omni model from export_hf_checkpoint."""
 
 import argparse
 import time

@@ -27,38 +27,41 @@
 
 import modelopt.torch.opt as mto
 
+# Enable HuggingFace checkpointing for modelopt quantized models
+mto.enable_huggingface_checkpointing()
+
 
 def main(args):
-    print(f"Loading base model from {args.model_path}...")
+    print(f"Loading quantized model from {args.checkpoint_path}...")
     model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
-        args.model_path,
+        args.checkpoint_path,
         torch_dtype="auto",
-        device_map="cuda",
+        device_map="auto",
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     )
 
-    print(f"Restoring quantized state from {args.checkpoint_path}...")
-    model = mto.restore(model, args.checkpoint_path)
-
     model.disable_talker()
 
     print("Loading processor...")
     processor = Qwen3OmniMoeProcessor.from_pretrained(
-        args.model_path,
+        "Qwen/Qwen3-Omni-30B-A3B-Thinking",
         trust_remote_code=True,
     )
 
     # Build conversation with user prompt
     prompt = args.prompt or "What is the capital of France?"
-    conversation = [{"role": "user", "content": [{"type": "text", "text": f"{prompt}"}]}]
+    conversation = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
     conversations = [conversation]
 
     # Set whether to use audio in video
     use_audio_in_video = True
 
     # Preparation for inference
-    texts = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
+    texts = processor.apply_chat_template(
+        conversations, add_generation_prompt=True, tokenize=False, enable_thinking=False
+    )
+    print(f"Texts: {texts}")
     audios, images, videos = process_mm_info(conversations, use_audio_in_video=use_audio_in_video)
 
     inputs = processor(

@@ -99,17 +102,11 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run quantized Qwen3Omni model")
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
-        help="Path to the base Qwen3Omni model (HF format)",
-    )
     parser.add_argument(
         "--checkpoint_path",
         type=str,
-        default="/home/scratch.arasane_hw/models/qwen3omni_nvfp4_qkv_disabled_text_bs512_calib512.pt",
-        help="Path to the mto.save() quantized checkpoint",
+        required=True,
+        help="Path to the export_hf_checkpoint() quantized checkpoint directory",
    )
     parser.add_argument(
         "--prompt",
modelopt/torch/export/unified_export_hf.py

Lines changed: 55 additions & 7 deletions
@@ -177,6 +177,19 @@ def _output_hook(module, input, output):
                 "This is required for requantization/resmoothing optimization. "
                 "Please ensure the model architecture is supported or file an issue."
             )
+        elif "qwen3omni" in model_type:
+            # For Qwen3Omni, run on the thinker (language model) component
+            # The model has structure: model.thinker.model.layers.*
+            if hasattr(model, "thinker"):
+                print(
+                    f"Running optimization on Qwen3Omni thinker with fake_input shape: {fake_input.shape}"
+                )
+                model.thinker(fake_input)
+            else:
+                raise ValueError(
+                    f"Cannot extract thinker from Qwen3Omni model (type: {model_type}). "
+                    "This is required for requantization/resmoothing optimization."
+                )
         else:
             model(fake_input)
 

@@ -248,6 +261,19 @@ def _export_quantized_weight(
     weight_quantizer: TensorQuantizer | SequentialQuantizer = getattr(
         sub_module, quantizer_attrs.weight_quantizer
     )
+
+    # Skip export if weight quantizer is disabled or has no amax (not calibrated)
+    if not _is_enabled_quantizer(weight_quantizer):
+        return
+
+    # Check if weight quantizer has calibrated amax
+    def _has_amax(quantizer):
+        if isinstance(quantizer, SequentialQuantizer):
+            return any(hasattr(q, "_amax") and q._amax is not None for q in quantizer)
+        return hasattr(quantizer, "_amax") and quantizer._amax is not None
+
+    if not _has_amax(weight_quantizer):
+        return
     input_quantizer: TensorQuantizer | SequentialQuantizer | None = getattr(
         sub_module, quantizer_attrs.input_quantizer, None
     )

@@ -392,7 +418,11 @@ def _export_quantized_weight(
 
 
 def _export_hf_checkpoint(
-    model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
+    model: nn.Module,
+    dtype: torch.dtype | None = None,
+    is_modelopt_qlora: bool = False,
+    pack_weights: bool = True,
+    **kwargs,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.
 

@@ -402,6 +432,7 @@ def _export_hf_checkpoint(
         model: the full torch model to export. The actual quantized model may be a submodule.
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         accelerator: the accelerator instance in case of distributed export setup.
+        pack_weights: whether to pack quantized weights (False keeps original shapes for HF reload).
 
     Returns:
         post_state_dict: Dict containing quantized weights

@@ -518,8 +549,9 @@ def _export_hf_checkpoint(
 
             if get_quantization_format(sub_module) != QUANTIZATION_NONE:
                 if is_quantlinear(sub_module):
-                    with fsdp2_aware_weight_update(model, sub_module, reshard=False):
-                        _export_quantized_weight(sub_module, dtype)
+                    if pack_weights:
+                        with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                            _export_quantized_weight(sub_module, dtype)
                 elif (
                     "Llama4TextExperts" in type(sub_module).__name__
                     or "GptOssExperts" in type(sub_module).__name__

@@ -536,9 +568,10 @@ def _export_hf_checkpoint(
                         quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
                     )
                     # Export the quantized weights
-                    with fsdp2_aware_weight_update(model, sub_module, reshard=False):
-                        for weight_name in ["gate_up_proj", "down_proj"]:
-                            _export_quantized_weight(sub_module, dtype, weight_name)
+                    if pack_weights:
+                        with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                            for weight_name in ["gate_up_proj", "down_proj"]:
+                                _export_quantized_weight(sub_module, dtype, weight_name)
 
     if accelerator is not None:
         # Gather state_dict from all ranks

@@ -579,7 +612,12 @@ def export_hf_checkpoint(
         return
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        # Packed weights are only for TRT-LLM consumption
+        # Set this to true if you want to save the weights in the original precision
+        pack_weights = True
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(
+            model, dtype, pack_weights=pack_weights
+        )
 
         if hf_quant_config is not None:
             # Save hf_quant_config.json for backward compatibility

@@ -588,6 +626,16 @@ def export_hf_checkpoint(
 
             hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
+        # Fix generation_config conflicts before saving
+        # Some models have temperature/top_p/top_k set but do_sample=False which causes validation errors
+        if hasattr(model, "generation_config") and model.generation_config is not None:
+            gen_config = model.generation_config
+            if not getattr(gen_config, "do_sample", True):
+                # Remove sampling-related params when do_sample is False
+                for attr in ["temperature", "top_p", "top_k"]:
+                    if hasattr(gen_config, attr):
+                        setattr(gen_config, attr, None)
+
         # Save model
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
modelopt/torch/utils/image_processor.py

Lines changed: 3 additions & 9 deletions
@@ -145,13 +145,9 @@ def preprocess_function(self, text: str) -> dict:
             Dictionary with tokenized inputs.
         """
         # Build conversation in Qwen format (text-only)
-        conversation = [
-            {"role": "user", "content": [{"type": "text", "text": "/no_think " + text}]}
-        ]
-
-        # Apply chat template (tokenize=False to get formatted string)
+        conversation = [{"role": "user", "content": [{"type": "text", "text": text}]}]
         formatted_text = self.tokenizer.apply_chat_template(
-            conversation, add_generation_prompt=True, tokenize=False
+            conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
         )
 
         # Tokenize with the processor (no multimodal inputs)

@@ -212,10 +208,8 @@ def preprocess_function(self, examples):
             content.append({"type": "text", "text": question})
 
         conversation = [{"role": "user", "content": content}]
-
-        # Apply chat template (tokenize=False to get string)
         text = self.tokenizer.apply_chat_template(
-            conversation, add_generation_prompt=True, tokenize=False
+            conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
         )
 
         # Extract multimodal info using qwen_omni_utils
modelopt/torch/utils/video_dataset_utils.py

Lines changed: 26 additions & 17 deletions
@@ -114,32 +114,41 @@ def get_video_dataset_dataloader(
 
     processed_dataset = None
 
-    # Try to load from cache
+    # Try to load from cache (use torch.save/load to avoid Arrow 32-bit offset overflow)
     if cache_dir is not None:
-        from datasets import load_from_disk
-
-        cache_path = os.path.join(cache_dir, f"{dataset_name}_n{num_samples}_processed")
+        cache_path = os.path.join(cache_dir, f"{dataset_name}_n{num_samples}_processed.pt")
         if os.path.exists(cache_path):
             try:
-                processed_dataset = load_from_disk(cache_path)
+                from datasets import Dataset
+
+                processed_samples = torch.load(cache_path, weights_only=False)
+                processed_dataset = Dataset.from_list(processed_samples)
                 print(f"Loaded processed dataset from cache: {cache_path}")
             except Exception as e:
                 print(f"Failed to load cache from {cache_path}: {e}. Reprocessing...")
                 processed_dataset = None
 
     # Process dataset if not loaded from cache
     if processed_dataset is None:
+        from datasets import Dataset
+
         dataset = _get_video_dataset(dataset_name, num_samples=num_samples)
-        # Apply the preprocessing function to the dataset
-        processed_dataset = dataset.map(
-            processor.preprocess_function, batched=False, remove_columns=dataset.column_names
-        )
 
-    # Save to cache if cache_dir is provided
+        # Process samples manually to avoid Arrow 32-bit offset overflow
+        # (dataset.map() uses Arrow internally which can't handle large nested lists)
+        processed_samples = []
+        for i, sample in enumerate(dataset):
+            processed = processor.preprocess_function(sample)
+            processed_samples.append(processed)
+            if (i + 1) % 10 == 0:
+                print(f"Processed {i + 1}/{len(dataset)} samples...")
+
+        processed_dataset = Dataset.from_list(processed_samples)
+
+    # Save to cache using torch.save to avoid Arrow 32-bit offset overflow
     if cache_dir is not None:
         os.makedirs(cache_dir, exist_ok=True)
-        # Use num_shards=1 to avoid off-by-one sharding bug with complex nested structures
-        processed_dataset.save_to_disk(cache_path, num_shards=1)
+        torch.save(processed_samples, cache_path)
         print(f"Saved processed dataset to cache: {cache_path}")
 
     # Create DataLoader with the custom collate function

@@ -204,9 +213,11 @@ def preprocess_function(self, examples):
             metadata = examples["json"]
             # Try to get a meaningful question from metadata
             category = metadata.get("content_fine_category", "")
-            question = f"/no_think Describe what is happening in this video in detail. Category hint: {category}"
+            question = (
+                f"Describe what is happening in this video in detail. Category hint: {category}"
+            )
         else:
-            question = examples.get("question", "/no_think Describe this video in detail.")
+            question = examples.get("question", "Describe this video in detail.")
 
         # Build conversation in Qwen format
         content = []

@@ -226,10 +237,8 @@ def preprocess_function(self, examples):
             content.append({"type": "text", "text": question})
 
         conversation = [{"role": "user", "content": content}]
-
-        # Apply chat template (tokenize=False to get string)
         text = self.tokenizer.apply_chat_template(
-            conversation, add_generation_prompt=True, tokenize=False
+            conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
        )
 
         # Extract multimodal info using qwen_omni_utils
0 commit comments
