Commit bb5771b

kevalmorabia97 authored and danielkorzekwa committed
Move parallel_state init and warnings to Quant DynamicModule + MBridge pruning doc update (#854)
## What does this PR do?

**Type of change:** Minor improvement

Only quantization DynamicModules use the `parallel_state` attribute, so every other ModelOpt method was emitting a confusing "parallel state not initialized" warning. This moves the attribute to the `QuantModule` class instead, along with a minor update to the MBridge pruning docs.

## Testing

N/A

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Summary by CodeRabbit

* **New Features**
  * Distributed parallel state support is now available in quantization workflows for multi-GPU training.
* **Bug Fixes**
  * Improved resource cleanup in distributed training to ensure proper environment finalization.
* **Documentation**
  * Updated example paths and added new manual pruning configuration examples.

---------

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
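For context, a minimal sketch (not part of this commit) of what the relocated attribute looks like from user code: after quantization, converted layers carry `parallel_state`, while non-quantization DynamicModules no longer do. The toy model, the choice of `INT8_DEFAULT_CFG`, and the calibration loop are illustrative.

```python
# Minimal sketch, not from this PR: inspect parallel_state on quantized layers.
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn.modules.quant_module import QuantModule

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))


def calibrate(m):
    # Tiny calibration loop so quantizer ranges get populated.
    for _ in range(4):
        m(torch.randn(2, 32))


model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop=calibrate)

for name, module in model.named_modules():
    if isinstance(module, QuantModule):
        # After this PR, parallel_state lives on QuantModule rather than on every
        # DynamicModule, so only quantized layers expose it.
        print(name, module.parallel_state)
```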
1 parent 33d4d27 commit bb5771b

5 files changed: 73 additions & 53 deletions

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
```diff
@@ -226,7 +226,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for UNet ONNX quantization.
 - Enable ``concat_elimination`` pass by default to improve the performance of quantized ONNX models.
 - Enable Redundant Cast elimination pass by default in :meth:`moq.quantize <modelopt.onnx.quantization.quantize>`.
-- Add new attribute ``parallel_state`` to :class:`DynamicModule <modelopt.torch.opt.dynamic.DynamicModule>` to support distributed parallelism such as data parallel and tensor parallel.
+- Add new attribute ``parallel_state`` to :class:`QuantModule <modelopt.torch.quantization.nn.modules.quant_module.QuantModule>` to support distributed parallelism such as data parallel and tensor parallel.
 - Add MXFP8, NVFP4 quantized ONNX export support.
 - Add new example for torch quantization to ONNX for MXFP8, NVFP4 precision.
```

examples/megatron_bridge/README.md

Lines changed: 13 additions & 3 deletions
````diff
@@ -18,7 +18,7 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Bri
 
 Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed.
 
-To get the latest ModelOpt features and examples, you can mount your latest ModelOpt cloned repository to the container at `/opt/Model-Optimizer` or pull the latest changes once inside the docker container (`cd /opt/Model-Optimizer && git checkout main && git pull`).
+To get the latest ModelOpt features and examples, you can mount your latest ModelOpt cloned repository to the container at `/opt/Megatron-Bridge/3rdparty/Model-Optimizer` or pull the latest changes once inside the docker container (`cd /opt/Megatron-Bridge/3rdparty/Model-Optimizer && git checkout main && git pull`).
 
 ## Pruning
 
@@ -30,17 +30,27 @@ Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while
 top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
 
 ```bash
-torchrun --nproc_per_node 2 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
+torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
     --hf_model_name_or_path Qwen/Qwen3-8B \
     --prune_target_params 6e9 \
     --hparams_to_skip num_attention_heads \
     --output_hf_path /tmp/Qwen3-8B-Pruned-6B
 ```
 
+Example usage for manually pruning to a specific architecture using the following defaults:
+1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration.
+
+```bash
+torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
+    --hf_model_name_or_path Qwen/Qwen3-8B \
+    --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \
+    --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual
+```
+
 To see the full usage for advanced configurations, run:
 
 ```bash
-python /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
+python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
 ```
 
 > [!TIP]
````

examples/megatron_bridge/prune_minitron.py

Lines changed: 7 additions & 9 deletions
```diff
@@ -91,7 +91,7 @@ def get_args() -> argparse.Namespace:
 
     # Pruning parameters
     parser.add_argument(
-        "--prune_intermediate_checkpoint",
+        "--prune_intermediate_ckpt",
         type=str,
         default=None,
         help=(
@@ -169,18 +169,16 @@ def get_args() -> argparse.Namespace:
     args = parser.parse_args()
 
     # Post-process arguments
-    if args.prune_intermediate_checkpoint is None:
+    if args.prune_intermediate_ckpt is None:
         if args.output_megatron_path:
-            args.prune_intermediate_checkpoint = (
+            args.prune_intermediate_ckpt = (
                 f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
             )
         elif args.output_hf_path:
-            args.prune_intermediate_checkpoint = (
-                f"{args.output_hf_path}/modelopt_pruning_scores.pth"
-            )
+            args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores.pth"
         print_rank_0(
             "No checkpoint provided to cache intermediate pruning scores. "
-            f"Setting to: {args.prune_intermediate_checkpoint}"
+            f"Setting to: {args.prune_intermediate_ckpt}"
         )
 
     if args.prune_export_config:
@@ -247,7 +245,7 @@ def main(args: argparse.Namespace):
 
     pruning_config = {
         "forward_loop": forward_loop,
-        "checkpoint": args.prune_intermediate_checkpoint,
+        "checkpoint": args.prune_intermediate_ckpt,
     }
     if args.prune_target_params is not None:
         # Restrict search space to a smaller set of candidates
@@ -377,8 +375,8 @@ def score_func_mmlu(m):
 
 
 if __name__ == "__main__":
-    args = get_args()
     dist.setup()
+    args = get_args()
     try:
         main(args)
     finally:
```
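As a usage note, a run that caches intermediate pruning scores via the renamed flag might look like the sketch below; the `/tmp` checkpoint path is illustrative and the remaining arguments mirror the README example above.

```bash
torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
    --hf_model_name_or_path Qwen/Qwen3-8B \
    --prune_target_params 6e9 \
    --prune_intermediate_ckpt /tmp/qwen3_pruning_scores.pth \
    --output_hf_path /tmp/Qwen3-8B-Pruned-6B
```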

modelopt/torch/opt/dynamic.py

Lines changed: 0 additions & 39 deletions
```diff
@@ -31,7 +31,6 @@
 from torch.nn.parameter import Parameter
 
 from modelopt.torch.utils import get_unwrapped_name, is_channels_last, unwrap_model
-from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method
 
 from .config import ModeloptBaseRule, RulesDict
@@ -359,14 +358,10 @@ class DynamicModule(nn.Module):
     should ensure only to expose ``hparams`` in the outermost class and handle other ``hparams``
     internally including ``hparams`` of child modules that are exposed on their own usually
     (e.g. block module implementations containing DynamicLinear).
-
-    In addition, the class also provides ``parallel_state`` attribute that can be used to access
-    the parallel state of the module.
     """
 
     # this is needed to store the special attributes for dynamic modules
     _dm_attribute_manager: _DMAttributeManager
-    _parallel_state: ParallelState
 
     def __init__(self, *args, **kwargs):
         """Initializing a dynamic module is not allowed!"""
@@ -657,10 +652,6 @@ def bind_forward_method_if_needed(self):
         # setup new hparams and dynamic attributes
         module._setup(**setup_kwargs)
 
-        # setup parallel state now that the module is converted
-        if module.parallel_state is None:
-            module._initialize_parallel_state()
-
         return module
 
     def _setup(self, **setup_kwargs: Any):
@@ -867,36 +858,6 @@ def original_cls(self) -> type[nn.Module]:
         """
         return self._get_dm_attribute_manager().og_cls
 
-    @property
-    def parallel_state(self) -> ParallelState | None:
-        """Return the parallel state of the dynamic module."""
-        return getattr(self, "_parallel_state", None)
-
-    @parallel_state.setter
-    def parallel_state(self, parallel_state: ParallelState):
-        """Set the parallel state of the dynamic module."""
-        assert isinstance(parallel_state, ParallelState), (
-            "parallel_state must be a ParallelState object!"
-        )
-        self._parallel_state = parallel_state
-
-    def _initialize_parallel_state(self):
-        """Initialize the parallel state of the dynamic module.
-
-        This method is called only if the `DynamicModule` does not have a `parallel_state` attribute
-        after `_setup` is called.
-        """
-        if torch.distributed.is_initialized():
-            warnings.warn(
-                f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
-                "Using default parallel_state which has data_parallel_group set to the default process group and "
-                "tensor_parallel_group is unspecified. "
-                "If you are using tensor parallelism for this module, you should set the parallel_state "
-                "in its `_setup` method."
-            )
-
-        self.parallel_state = ParallelState(data_parallel_group=None)
-
     def get_original_cls_by_level(self, level: int = -1) -> type[nn.Module]:
         """Return the original class of the dynamic module.
 
```

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 52 additions & 1 deletion
```diff
@@ -17,10 +17,13 @@
 
 import contextlib
 import warnings
+from typing import Any
 
 import torch
+import torch.nn as nn
 
 from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
+from modelopt.torch.utils.distributed import ParallelState
 
 from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
 from ...utils import is_torch_export_mode
@@ -35,7 +38,55 @@
 
 
 class QuantModule(DynamicModule):
-    """A base class for quantized modules."""
+    """A base class for quantized modules.
+
+    In addition, the class also provides ``parallel_state`` attribute that can be used to access
+    the parallel state of the module.
+    """
+
+    _parallel_state: ParallelState
+
+    @classmethod
+    @torch.no_grad()
+    def convert(cls, module: nn.Module, **setup_kwargs: Any) -> "QuantModule":
+        """Convert the module to a dynamic module."""
+        module = super().convert(module, **setup_kwargs)
+
+        # setup parallel state now that the module is converted
+        if module.parallel_state is None:
+            module._initialize_parallel_state()
+
+        return module
+
+    @property
+    def parallel_state(self) -> ParallelState | None:
+        """Return the parallel state of the quant module."""
+        return getattr(self, "_parallel_state", None)
+
+    @parallel_state.setter
+    def parallel_state(self, parallel_state: ParallelState):
+        """Set the parallel state of the dynamic module."""
+        assert isinstance(parallel_state, ParallelState), (
+            "parallel_state must be a ParallelState object!"
+        )
+        self._parallel_state = parallel_state
+
+    def _initialize_parallel_state(self):
+        """Initialize the parallel state of the dynamic module.
+
+        This method is called only if the `QuantModule` does not have a `parallel_state` attribute
+        after `_setup` is called.
+        """
+        if torch.distributed.is_initialized():
+            warnings.warn(
+                f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
+                "Using default parallel_state which has data_parallel_group set to the default process group and "
+                "tensor_parallel_group is unspecified. "
+                "If you are using tensor parallelism for this module, you should set the parallel_state "
+                "in its `_setup` method."
+            )
+
+        self.parallel_state = ParallelState(data_parallel_group=None)
 
     def modelopt_post_restore(self, prefix: str = ""):
         """Post-restore to correctly configure the TensorQuantizer states.
```
