Skip to content

Commit 6f18490

Browse files
authored
Improve AWQ init speed (#748)
## What does this PR do? **Type of change:** Improvement <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Improve the speed of accessing weights through enable_weight_access_and_writeback in AWQ helper init. This change reduces the time complexity from O(num_modules^2) to O(num_modules) and the runtime from ~1 hour to ~30 seconds. ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> python hf_ptq.py --pyt_ckpt_path /home/scratch.omniml_data_1/models/qwen/Qwen3-30B-A3B-Instruct-2507 --qformat int4_awq ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 9c24e2c commit 6f18490

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,11 @@ def awq(
532532
awq_clip(model, forward_loop, **kwargs)
533533

534534
# Special handling for SequentialQuantizer
535+
# Pre-compute name_to_module dict to avoid O(n^2) complexity in enable_weight_access_and_writeback
536+
name_to_module = dict(model.named_modules())
535537
for name, module in model.named_modules():
536538
if is_quantized_linear(module) and isinstance(module.weight_quantizer, SequentialQuantizer):
537-
with enable_weight_access_and_writeback(module, model):
539+
with enable_weight_access_and_writeback(module, model, name_to_module):
538540
max_calibrate(module, lambda linear: linear.weight_quantizer(module.weight))
539541

540542

@@ -606,8 +608,9 @@ def get_weight_scale(weight, block_size=None):
606608
weight = F.pad(weight, (0, block_size - org_shape[-1] % block_size), "constant", 0)
607609
org_shape = weight.shape
608610
weight = weight.contiguous().view(-1, block_size)
609-
weight_abs_amax = weight.abs().amax(dim=1, keepdim=True)
610-
scale = weight.abs() / (weight_abs_amax + torch.finfo(weight.dtype).tiny)
611+
weight_abs = weight.abs() # Cache to avoid redundant computation
612+
weight_abs_amax = weight_abs.amax(dim=1, keepdim=True)
613+
scale = weight_abs / (weight_abs_amax + torch.finfo(weight.dtype).tiny)
611614
scale = scale.view(org_shape)
612615
if slice_after_padding is not None:
613616
scale = scale[..., slice_after_padding]
@@ -701,9 +704,11 @@ def forward(self, input, *args, **kwargs):
701704
# Now forward the actual output without any quantization
702705
return out_actual
703706

704-
for name, module in model.named_modules():
707+
# Pre-compute name_to_module dict ONCE to avoid O(n^2) complexity in enable_weight_access_and_writeback
708+
name_to_module = dict(model.named_modules())
709+
for name, module in name_to_module.items():
705710
if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
706-
with enable_weight_access_and_writeback(module, model):
711+
with enable_weight_access_and_writeback(module, model, name_to_module):
707712
module.awq_lite = AWQLiteHelper(module, name)
708713
module.awq_lite.setup()
709714

@@ -793,7 +798,7 @@ def postprocess(module, name):
793798
f" {name}. Please provide a valid `forward_loop` function that can be used to"
794799
" forward data through the model many times."
795800
)
796-
with enable_weight_access_and_writeback(module, model):
801+
with enable_weight_access_and_writeback(module, model, name_to_module):
797802
postprocess(module, name)
798803

799804
module.awq_lite.cleanup()
@@ -973,14 +978,16 @@ def forward(name, self, input, *args, **kwargs):
973978
self.weight_quantizer.disable()
974979
return self._forward_no_awq(input, *args, **kwargs)
975980

981+
# Pre-compute name_to_module dict to avoid O(n^2) complexity in enable_weight_access_and_writeback
982+
name_to_module = dict(model.named_modules())
976983
for name, module in model.named_modules():
977984
if (
978985
is_quantized_linear(module)
979986
and module.weight_quantizer.is_enabled
980987
and module.weight_quantizer.block_sizes is not None
981988
):
982989
bind_forward_method(module, partial(forward, name), "_forward_no_awq")
983-
with enable_weight_access_and_writeback(module, model):
990+
with enable_weight_access_and_writeback(module, model, name_to_module):
984991
module.awq_clip = AWQClipHelper(module)
985992

986993
print_rank_0("awq_clip: Estimating parameters...")
@@ -1004,7 +1011,7 @@ def postprocess(module):
10041011
for name, module in model.named_modules():
10051012
if is_quantized_linear(module) and hasattr(module, "awq_clip"):
10061013
if module.awq_clip.num_tokens > 0:
1007-
with enable_weight_access_and_writeback(module, model):
1014+
with enable_weight_access_and_writeback(module, model, name_to_module):
10081015
postprocess(module)
10091016

10101017
if not debug:

modelopt/torch/quantization/utils.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -396,19 +396,30 @@ def _get_fsdp2_mesh(module: nn.Module):
396396
return fsdp_state._fsdp_param_group.post_forward_mesh_info.mesh
397397

398398

399-
def _get_module_name(module: nn.Module, root_model: nn.Module):
400-
name_to_module = dict(root_model.named_modules())
399+
def _get_module_name(module: nn.Module, root_model: nn.Module, name_to_module: dict | None = None):
400+
if name_to_module is None:
401+
name_to_module = dict(root_model.named_modules())
401402
target_module_name = next((name for name, m in name_to_module.items() if m is module), None)
402403
return target_module_name
403404

404405

405-
def _get_enclosing_fsdp_module(module: nn.Module, root_model: nn.Module):
406-
"""Get the enclosing FSDP module for a given module."""
406+
def _get_enclosing_fsdp_module(
407+
module: nn.Module, root_model: nn.Module, name_to_module: dict | None = None
408+
):
409+
"""Get the enclosing FSDP module for a given module.
410+
411+
Args:
412+
module: The module to find the enclosing FSDP for.
413+
root_model: The root model containing the module.
414+
name_to_module: Optional pre-computed dict mapping names to modules (for performance).
415+
"""
407416
if isinstance(module, FSDPModule):
408417
return module
409418

410-
name_to_module = dict(root_model.named_modules())
411-
target_module_name = _get_module_name(module, root_model)
419+
if name_to_module is None:
420+
name_to_module = dict(root_model.named_modules())
421+
422+
target_module_name = _get_module_name(module, root_model, name_to_module)
412423

413424
if target_module_name is None:
414425
raise ValueError(f"Module {module} not found in the root model {root_model}.")
@@ -469,13 +480,19 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.
469480

470481

471482
@contextmanager
472-
def enable_weight_access_and_writeback(module, root_model):
483+
def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | None = None):
473484
"""Enable weight access and writeback for a module.
474485
475486
Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or
476487
HF accelerate CPU off-loaded models.
488+
489+
Args:
490+
module: The module to access weights for.
491+
root_model: The root model containing the module.
492+
name_to_module: Optional pre-computed dict mapping names to modules (for performance).
493+
If not provided, will be computed on-the-fly.
477494
"""
478-
if _get_enclosing_fsdp_module(module, root_model) is not None:
495+
if _get_enclosing_fsdp_module(module, root_model, name_to_module) is not None:
479496
context = fsdp2_weight_access_and_writeback_context(module, root_model)
480497
elif is_quantized_parallel_linear(module) and hasattr(module, "_hf_tp_plan"):
481498
# HF transformers TP sharded linear layer

0 commit comments

Comments
 (0)