Skip to content

Commit 20a4b33

Browse files
committed
Address comments
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 9e3b399 commit 20a4b33

File tree

7 files changed

+288
-335
lines changed

7 files changed

+288
-335
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 13 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
except ImportError:
4646
snapshot_download = None
4747

48-
from modelopt.torch.export.model_utils import MODEL_NAME_TO_TYPE
48+
from modelopt.torch.export.model_utils import match_model_type_by_name
4949
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
5050
from modelopt.torch.utils.image_processor import (
5151
BaseImageProcessor,
@@ -95,19 +95,13 @@ def get_model_type_from_config(model_path: str) -> str | None:
9595
config = json.load(f)
9696

9797
# Check architectures field first
98-
architectures = config.get("architectures", [])
99-
for arch in architectures:
100-
for key, model_type in MODEL_NAME_TO_TYPE.items():
101-
if key.lower() in arch.lower():
102-
return model_type
98+
for arch in config.get("architectures", []):
99+
result = match_model_type_by_name(arch)
100+
if result is not None:
101+
return result
103102

104103
# Fallback to model_type field
105-
model_type_field = config.get("model_type", "")
106-
for key, model_type in MODEL_NAME_TO_TYPE.items():
107-
if key.lower() in model_type_field.lower():
108-
return model_type
109-
110-
return None
104+
return match_model_type_by_name(config.get("model_type", ""))
111105

112106

113107
def get_sampling_params_from_config(model_path: str) -> dict:
@@ -164,10 +158,13 @@ def ensure_tokenizer_files(model_path: str, source_model_id: str) -> None:
164158

165159
print(f"Copying missing tokenizer files from {source_model_id}...")
166160
# Download only tokenizer files from HF
167-
cache_dir = snapshot_download(
168-
source_model_id,
169-
allow_patterns=TOKENIZER_FILES,
170-
)
161+
if os.path.isdir(source_model_id):
162+
cache_dir = source_model_id
163+
else:
164+
cache_dir = snapshot_download(
165+
source_model_id,
166+
allow_patterns=TOKENIZER_FILES,
167+
)
171168

172169
for fname in TOKENIZER_FILES:
173170
src = os.path.join(cache_dir, fname)
@@ -992,55 +989,6 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
992989
print("No custom model files found to copy")
993990

994991

995-
def patch_config_for_unified_export(model_type: str, export_path: str) -> None:
996-
"""Patch config files to add missing exclusion patterns for unified HF export.
997-
998-
This function adds missing exclusion patterns for modules that should not be quantized
999-
(e.g., audio tower, visual encoder, lm_head) to both hf_quant_config.json and config.json.
1000-
1001-
Args:
1002-
export_path: Path to the exported model directory.
1003-
"""
1004-
if model_type == "qwen3omni":
1005-
missing_patterns = [
1006-
"thinker.audio_tower*",
1007-
"thinker.visual*",
1008-
"thinker.lm_head",
1009-
]
1010-
1011-
# (filename, path_to_exclude_list)
1012-
configs = [
1013-
("hf_quant_config.json", ["quantization", "exclude_modules"]),
1014-
("config.json", ["quantization_config", "ignore"]),
1015-
]
1016-
1017-
for filename, keys in configs:
1018-
filepath = os.path.join(export_path, filename)
1019-
if not os.path.exists(filepath):
1020-
continue
1021-
try:
1022-
with open(filepath) as f:
1023-
config = json.load(f)
1024-
1025-
# Navigate to nested key
1026-
target = config
1027-
for key in keys[:-1]:
1028-
target = target.get(key, {})
1029-
1030-
exclude_list = target.get(keys[-1])
1031-
if exclude_list is None:
1032-
continue
1033-
1034-
added = [p for p in missing_patterns if p not in exclude_list]
1035-
if added:
1036-
exclude_list.extend(added)
1037-
with open(filepath, "w") as f:
1038-
json.dump(config, f, indent=2)
1039-
print(f"Patched {filename} with exclusions: {added}")
1040-
except Exception as e:
1041-
print(f"Warning: Failed to patch {filename}: {e}")
1042-
1043-
1044992
def get_qwen3omni_dataloader(
1045993
dataset_name: str | list[str] | None,
1046994
processor: Qwen3OmniImageProcessor | None,

examples/llm_ptq/hf_ptq.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import random
1919
import time
2020
import warnings
21+
from collections import namedtuple
2122
from typing import Any
2223

2324
import numpy as np
@@ -35,7 +36,6 @@
3536
is_enc_dec,
3637
is_nemotron_vl,
3738
load_mtp_weights,
38-
patch_config_for_unified_export,
3939
run_nemotron_vl_preview,
4040
)
4141
from torch.utils.data import DataLoader
@@ -735,9 +735,6 @@ def export_quantized(
735735
extra_state_dict=mtp_state_dict,
736736
)
737737

738-
# Exclude non-quantized modules in config.json and hf_quant_config.json
739-
patch_config_for_unified_export(model_type, export_path)
740-
741738
# Restore default padding and export the tokenizer as well.
742739
if tokenizer is not None:
743740
tokenizer.padding_side = default_padding_side
@@ -757,6 +754,23 @@ def export_quantized(
757754
)
758755

759756

757+
PreQuantizeResult = namedtuple(
758+
"PreQuantizeResult", ["preview_input_ids", "generated_ids_before_ptq", "calib_batch"]
759+
)
760+
761+
762+
def _qwen3omni_generate(model, calib_batch):
763+
"""Run Qwen3Omni generate and unpack the result.
764+
765+
Qwen3Omni returns a (text_ids, audio) tuple; text_ids may have a .sequences attribute.
766+
"""
767+
result = model.generate(**calib_batch, return_audio=False, thinker_max_new_tokens=100)
768+
if isinstance(result, tuple):
769+
text_ids, _ = result
770+
return text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
771+
return result
772+
773+
760774
def pre_quantize(
761775
args: argparse.Namespace,
762776
full_model: torch.nn.Module,
@@ -799,20 +813,15 @@ def pre_quantize(
799813
allow_fallback=False,
800814
)
801815
elif model_type == "qwen3omni":
802-
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
803-
# Pass full batch with all multimodal inputs
804-
result = full_model.generate(**calib_batch, return_audio=False, thinker_max_new_tokens=100)
805-
if isinstance(result, tuple):
806-
text_ids, _ = result
807-
generated_ids_before_ptq = (
808-
text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
809-
)
810-
else:
811-
generated_ids_before_ptq = result
816+
# Use only a single sample for preview generation to avoid OOM
817+
single_sample = {
818+
k: v[0:1] if isinstance(v, torch.Tensor) else v for k, v in calib_batch.items()
819+
}
820+
generated_ids_before_ptq = _qwen3omni_generate(full_model, single_sample)
812821
else:
813822
generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
814823

815-
return preview_input_ids, generated_ids_before_ptq, calib_batch
824+
return PreQuantizeResult(preview_input_ids, generated_ids_before_ptq, calib_batch)
816825

817826

818827
def post_quantize(
@@ -861,25 +870,23 @@ def post_quantize(
861870
"""
862871

863872
if args.verbose:
864-
mtq.print_quant_summary(full_model, save_path=args.quant_summary_path)
865-
save_expert_token_count_table(full_model, args.export_path)
873+
try:
874+
mtq.print_quant_summary(full_model, save_path=args.quant_summary_path)
875+
save_expert_token_count_table(full_model, args.export_path)
876+
except Exception as e:
877+
print(f"Warning: Failed to print quant summary: {e}")
866878

867879
# Run some samples
868880
torch.cuda.empty_cache()
869881
generated_ids_after_ptq = None
870882
if generated_ids_before_ptq is None:
871883
pass
872-
elif model_type == "qwen3omni":
873-
# Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
874-
# Pass full batch with all multimodal inputs
875-
result = full_model.generate(**calib_batch, return_audio=False, thinker_max_new_tokens=100)
876-
if isinstance(result, tuple):
877-
text_ids, _ = result
878-
generated_ids_after_ptq = (
879-
text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
880-
)
881-
else:
882-
generated_ids_after_ptq = result
884+
elif model_type == "qwen3omni" and calib_batch is not None:
885+
# Use only a single sample for preview generation to avoid OOM
886+
single_sample = {
887+
k: v[0:1] if isinstance(v, torch.Tensor) else v for k, v in calib_batch.items()
888+
}
889+
generated_ids_after_ptq = _qwen3omni_generate(full_model, single_sample)
883890
elif model_type != "llama4" and not is_nemotron_vl_model:
884891
# Our fake quantizer may not be fully compatible with torch.compile.
885892
generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)

modelopt/torch/export/model_utils.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,35 @@
6767
{MODEL_NAME_TO_TYPE=}
6868
"""
6969

70-
__all__ = ["get_language_model_from_vl", "get_model_type", "is_multimodal_model"]
70+
__all__ = [
71+
"get_language_model_from_vl",
72+
"get_model_type",
73+
"is_multimodal_model",
74+
"match_model_type_by_name",
75+
]
7176

7277

73-
def get_model_type(model):
74-
"""Try get the model type from the model name. If not found, return None."""
78+
def match_model_type_by_name(name: str) -> str | None:
79+
"""Match a model type from MODEL_NAME_TO_TYPE by case-insensitive substring match.
80+
81+
Args:
82+
name: String to match against (e.g. class name, architecture string, model_type field).
83+
84+
Returns:
85+
Matched model type string, or None.
86+
"""
87+
name_lower = name.lower()
7588
for k, v in MODEL_NAME_TO_TYPE.items():
76-
if k.lower() in type(model).__name__.lower():
89+
if k.lower() in name_lower:
7790
return v
7891
return None
7992

8093

94+
def get_model_type(model):
95+
"""Try get the model type from the model name. If not found, return None."""
96+
return match_model_type_by_name(type(model).__name__)
97+
98+
8199
def is_multimodal_model(model):
82100
"""Check if a model is a Vision-Language Model (VLM) or multimodal model.
83101

modelopt/torch/export/unified_export_hf.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
QUANTIZATION_W4A8_AWQ,
8888
QUANTIZATION_W4A8_NVFP4_FP8,
8989
)
90-
from .model_utils import get_language_model_from_vl, is_multimodal_model
90+
from .model_utils import get_language_model_from_vl, get_model_type, is_multimodal_model
9191
from .plugins import SpeculativeDecodingExporter, has_spec_opt
9292
from .quant_utils import (
9393
fuse_prequant_layernorm,
@@ -781,6 +781,16 @@ def _export_transformers_checkpoint(
781781
exclude_modules.append(pattern)
782782
print(f"Adding MTP layer to quantization_config ignore: {pattern}")
783783

784+
# Add model-specific non-quantized module exclusions
785+
_model_type_exclusions = {
786+
"qwen3omni": ["thinker.audio_tower*", "thinker.visual*", "thinker.lm_head"],
787+
}
788+
model_type = get_model_type(model)
789+
for pattern in _model_type_exclusions.get(model_type, []):
790+
exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
791+
if pattern not in exclude_modules:
792+
exclude_modules.append(pattern)
793+
784794
# Safety net: sync any gate/up weight quantizer amaxes that
785795
# requantize_resmooth_fused_llm_layers did not reach (e.g. experts not
786796
# activated during the dummy forward, or non-standard expert naming).
@@ -1185,6 +1195,8 @@ def export_hf_checkpoint(
11851195

11861196
# Fix generation_config conflicts before saving
11871197
# Some models have temperature/top_p/top_k set but do_sample=False which causes validation errors
1198+
# Restore the original value after save to avoid mutating the caller's model.
1199+
_gen_config_restore = None
11881200
if hasattr(model, "generation_config") and model.generation_config is not None:
11891201
gen_config = model.generation_config
11901202
if not getattr(gen_config, "do_sample", True):
@@ -1193,6 +1205,7 @@ def export_hf_checkpoint(
11931205
getattr(gen_config, attr, None) is not None
11941206
for attr in ["temperature", "top_p", "top_k"]
11951207
):
1208+
_gen_config_restore = gen_config.do_sample
11961209
gen_config.do_sample = True
11971210

11981211
# Save model
@@ -1211,6 +1224,8 @@ def export_hf_checkpoint(
12111224
)
12121225
finally:
12131226
_unpatch_revert_weight_conversion(_patches)
1227+
if _gen_config_restore is not None:
1228+
model.generation_config.do_sample = _gen_config_restore
12141229

12151230
original_config = f"{export_dir}/config.json"
12161231
config_data = {}

0 commit comments

Comments
 (0)