Skip to content

Commit 10f1140

Browse files
committed
fix issues caused by rebase and simplify
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent 7bce92e commit 10f1140

File tree

2 files changed

+30
-60
lines changed

2 files changed

+30
-60
lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 6 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,6 @@
6666
)
6767
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
6868
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
69-
from modelopt.torch.utils.nemotron_vlm_dataset_utils import get_nemotron_vlm_dataset_dataloader
7069
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
7170
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
7271

@@ -142,7 +141,6 @@ def make_calib_dataloader(
142141
tokenizer: PreTrainedTokenizerBase | None,
143142
device: torch.device,
144143
model_type: str | None,
145-
full_model: torch.nn.Module | None = None,
146144
) -> tuple[DataLoader, str | None]:
147145
calib_dataloader = None
148146
first_text_speech_dataset = None
@@ -525,12 +523,6 @@ def mono_quantize(
525523
"Consider reducing calib_size to reduce calibration time.\n####\n"
526524
)
527525

528-
# Check if this is Nemotron-Parse
529-
config = full_model.config
530-
architectures = getattr(config, "architectures", [])
531-
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
532-
original_forward = None # Track original forward method if we wrap it
533-
534526
# For Nemotron VL models, disable quantization of vision components
535527
if is_nemotron_vl_model:
536528
print("Disabling quantization for vision components in Nemotron VL model")
@@ -569,15 +561,8 @@ def mono_quantize(
569561
else:
570562
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)
571563

572-
# Restore original forward method if we wrapped it for Nemotron-Parse
573-
if is_nemotron_parse and original_forward is not None:
574-
print("Restoring original forward method after calibration")
575-
language_model.forward = original_forward
576-
original_forward = None
577-
578-
# For VL models (except Nemotron-Parse), update full_model to use the quantized language model
579-
# For Nemotron-Parse, language_model IS full_model, so no update needed
580-
if is_nemotron_vl_model and language_model is not full_model:
564+
# For VL models, update full_model to use the quantized language model
565+
if is_nemotron_vl_model:
581566
language_model_lineage = get_language_model_from_vl(full_model)
582567
if language_model_lineage is not None:
583568
print("Updating full_model with quantized language_model...")
@@ -717,20 +702,10 @@ def pre_quantize(
717702
post-quantize generation.
718703
719704
"""
720-
# Check if this is Nemotron-Parse (encoder-decoder model)
721-
config = full_model.config
722-
architectures = getattr(config, "architectures", [])
723-
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)
724-
725705
# Only run single sample for preview
726-
# For Nemotron-Parse, use decoder_input_ids instead of input_ids
727-
sample_batch = next(iter(calib_dataloader))
728-
if is_nemotron_parse and "decoder_input_ids" in sample_batch:
729-
preview_input_ids = sample_batch["decoder_input_ids"][0:1]
730-
elif model_type == "whisper":
731-
preview_input_ids = sample_batch["input_features"][0:1]
732-
else:
733-
preview_input_ids = sample_batch["input_ids"][0:1]
706+
preview_input_ids = next(iter(calib_dataloader))[
707+
"input_features" if model_type == "whisper" else "input_ids"
708+
][0:1]
734709

735710
# Generate preview before quantization
736711
if model_type == "deepseek":
@@ -901,7 +876,7 @@ def quantize_main(
901876
print(f"Use calib batch_size {args.batch_size}")
902877

903878
calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
904-
args, language_model, processor, tokenizer, device, model_type, full_model
879+
args, language_model, processor, tokenizer, device, model_type
905880
)
906881

907882
# Detect if this is a Nemotron VL model using architecture-based detection

modelopt/torch/export/unified_export_hf.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,13 @@ def _collect_shared_input_modules(
148148
def _input_hook(module, input, output):
149149
"""Update dictionary with list of all modules that share the same input."""
150150
if len(input) > 0 and isinstance(input[0], torch.Tensor):
151-
# TODO: Handle DBRX MoE case
152-
input_to_linear[input[0]].append(module)
151+
# TODO: Handle DBRX MoE case
152+
input_to_linear[input[0]].append(module)
153153

154154
def _output_hook(module, input, output):
155155
"""Update dictionary with mapping of layernorms and their outputs."""
156156
if output_to_layernorm is not None and isinstance(output, torch.Tensor):
157-
output_to_layernorm[output] = module
157+
output_to_layernorm[output] = module
158158

159159
handles = []
160160

@@ -323,29 +323,29 @@ def llm_dummy_forward():
323323
if is_vl_model and ("nemotron" in model_type or is_nemotron_parse):
324324
# For Nemotron VL models (including Nemotron-Parse), run optimization on just the
325325
# language model/decoder. This avoids needing pixel_values for the vision encoder.
326-
language_model_lineage = get_language_model_from_vl(model)
326+
language_model_lineage = get_language_model_from_vl(model)
327327

328-
if language_model_lineage is not None:
329-
language_model = language_model_lineage[-1]
330-
print(
331-
f"Running optimization on language model with fake_input shape: {fake_input.shape}"
332-
)
333-
# For Nemotron-Parse decoder, force use_cache=False to avoid tuple index errors
334-
if is_nemotron_parse:
335-
language_model(fake_input, use_cache=False)
336-
else:
337-
language_model(fake_input)
328+
if language_model_lineage is not None:
329+
language_model = language_model_lineage[-1]
330+
print(
331+
f"Running optimization on language model with fake_input shape: {fake_input.shape}"
332+
)
333+
# For Nemotron-Parse decoder, force use_cache=False to avoid tuple index errors
334+
if is_nemotron_parse:
335+
language_model(fake_input, use_cache=False)
338336
else:
339-
raise ValueError(
340-
f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
341-
"This is required for requantization/resmoothing optimization. "
342-
"Please ensure the model architecture is supported or file an issue."
343-
)
337+
language_model(fake_input)
338+
else:
339+
raise ValueError(
340+
f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
341+
"This is required for requantization/resmoothing optimization. "
342+
"Please ensure the model architecture is supported or file an issue."
343+
)
344344
elif getattr(model.config, "is_encoder_decoder", False):
345345
# For other encoder-decoder models (non-VL), pass both encoder and decoder input ids
346346
model(fake_input, decoder_input_ids=decoder_fake_input)
347-
else:
348-
model(fake_input)
347+
else:
348+
model(fake_input)
349349

350350
input_to_linear, output_to_layernorm = _collect_shared_input_modules(
351351
model, llm_dummy_forward, collect_layernorms=True
@@ -440,19 +440,14 @@ def _export_quantized_weight(
440440
weight_scaling_factor,
441441
)
442442

443-
sub_module.register_buffer(
444-
quantizer_attrs.weight_scale,
445-
weight_scaling_factor,
446-
)
447-
448443
if hasattr(input_quantizer, "_amax") or (
449444
input_quantizer is not None
450445
and hasattr(input_quantizer, "amax")
451446
and input_quantizer.amax is not None
452447
):
453448
assert input_quantizer is not None
454449
if hasattr(input_quantizer, "_amax") and input_quantizer._amax is not None:
455-
input_quantizer._amax = input_quantizer._amax.to(torch.float32)
450+
input_quantizer._amax = input_quantizer._amax.to(torch.float32)
456451

457452
sub_module.register_buffer(
458453
quantizer_attrs.input_scale,
@@ -468,7 +463,7 @@ def _export_quantized_weight(
468463
):
469464
assert output_quantizer is not None
470465
if hasattr(output_quantizer, "_amax") and output_quantizer._amax is not None:
471-
output_quantizer._amax = output_quantizer._amax.to(torch.float32)
466+
output_quantizer._amax = output_quantizer._amax.to(torch.float32)
472467
else:
473468
# Register weight_scale and input_scale
474469
if quantization_format == QUANTIZATION_FP8_PB_REAL:
@@ -485,7 +480,7 @@ def _export_quantized_weight(
485480
)
486481
sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale)
487482
if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
488-
del weight_quantizer._scale
483+
del weight_quantizer._scale
489484
else:
490485
sub_module.register_buffer(
491486
quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name)

0 commit comments

Comments
 (0)