Commit 4fee933

Move parallel state init and warnings to Quant DynamicModule
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 883af35

3 files changed: 49 additions & 41 deletions


CHANGELOG.rst (1 addition, 1 deletion)

@@ -226,7 +226,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for UNet ONNX quantization.
 - Enable ``concat_elimination`` pass by default to improve the performance of quantized ONNX models.
 - Enable Redundant Cast elimination pass by default in :meth:`moq.quantize <modelopt.onnx.quantization.quantize>`.
-- Add new attribute ``parallel_state`` to :class:`DynamicModule <modelopt.torch.opt.dynamic.DynamicModule>` to support distributed parallelism such as data parallel and tensor parallel.
+- Add new attribute ``parallel_state`` to :class:`QuantModule <modelopt.torch.quantization.nn.modules.quant_module.QuantModule>` to support distributed parallelism such as data parallel and tensor parallel.
 - Add MXFP8, NVFP4 quantized ONNX export support.
 - Add new example for torch quantization to ONNX for MXFP8, NVFP4 precision.
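With the changelog entry now pointing at ``QuantModule``, the attribute is read off quantized layers after conversion. A minimal sketch of what that looks like (assuming the standard ``mtq.quantize`` entry point and the ``INT8_DEFAULT_CFG`` preset, neither of which is part of this diff):

import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Linear(16, 16))

def forward_loop(m):
    m(torch.randn(2, 16))  # one dummy calibration pass

model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# The converted layer is a QuantModule, so parallel_state was populated by
# _initialize_parallel_state() during convert (the default state, with no
# warning, when torch.distributed is not initialized).
print(model[0].parallel_state)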

modelopt/torch/opt/dynamic.py (0 additions, 39 deletions)

@@ -31,7 +31,6 @@
 from torch.nn.parameter import Parameter

 from modelopt.torch.utils import get_unwrapped_name, is_channels_last, unwrap_model
-from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method

 from .config import ModeloptBaseRule, RulesDict

@@ -359,14 +358,10 @@ class DynamicModule(nn.Module):
     should ensure only to expose ``hparams`` in the outermost class and handle other ``hparams``
     internally including ``hparams`` of child modules that are exposed on their own usually
     (e.g. block module implementations containing DynamicLinear).
-
-    In addition, the class also provides ``parallel_state`` attribute that can be used to access
-    the parallel state of the module.
     """

     # this is needed to store the special attributes for dynamic modules
     _dm_attribute_manager: _DMAttributeManager
-    _parallel_state: ParallelState

     def __init__(self, *args, **kwargs):
         """Initializing a dynamic module is not allowed!"""

@@ -657,10 +652,6 @@ def bind_forward_method_if_needed(self):
         # setup new hparams and dynamic attributes
         module._setup(**setup_kwargs)

-        # setup parallel state now that the module is converted
-        if module.parallel_state is None:
-            module._initialize_parallel_state()
-
         return module

     def _setup(self, **setup_kwargs: Any):

@@ -867,36 +858,6 @@ def original_cls(self) -> type[nn.Module]:
         """
         return self._get_dm_attribute_manager().og_cls

-    @property
-    def parallel_state(self) -> ParallelState | None:
-        """Return the parallel state of the dynamic module."""
-        return getattr(self, "_parallel_state", None)
-
-    @parallel_state.setter
-    def parallel_state(self, parallel_state: ParallelState):
-        """Set the parallel state of the dynamic module."""
-        assert isinstance(parallel_state, ParallelState), (
-            "parallel_state must be a ParallelState object!"
-        )
-        self._parallel_state = parallel_state
-
-    def _initialize_parallel_state(self):
-        """Initialize the parallel state of the dynamic module.
-
-        This method is called only if the `DynamicModule` does not have a `parallel_state` attribute
-        after `_setup` is called.
-        """
-        if torch.distributed.is_initialized():
-            warnings.warn(
-                f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
-                "Using default parallel_state which has data_parallel_group set to the default process group and "
-                "tensor_parallel_group is unspecified. "
-                "If you are using tensor parallelism for this module, you should set the parallel_state "
-                "in its `_setup` method."
-            )
-
-        self.parallel_state = ParallelState(data_parallel_group=None)
-
     def get_original_cls_by_level(self, level: int = -1) -> type[nn.Module]:
         """Return the original class of the dynamic module.
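After these deletions, ``parallel_state`` is no longer part of the generic ``DynamicModule`` surface used by other optimization modes; only the quantization subclass below carries it. A quick check one could run against this revision (illustrative, not from the diff):

from modelopt.torch.opt.dynamic import DynamicModule
from modelopt.torch.quantization.nn.modules.quant_module import QuantModule

# The property moved: absent on the base class, present on QuantModule.
assert not hasattr(DynamicModule, "parallel_state")
assert isinstance(QuantModule.parallel_state, property)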

modelopt/torch/quantization/nn/modules/quant_module.py (48 additions, 1 deletion)

@@ -21,6 +21,7 @@
 import torch

 from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
+from modelopt.torch.utils.distributed import ParallelState

 from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
 from ...utils import is_torch_export_mode

@@ -35,7 +36,53 @@


 class QuantModule(DynamicModule):
-    """A base class for quantized modules."""
+    """A base class for quantized modules.
+
+    In addition, the class also provides ``parallel_state`` attribute that can be used to access
+    the parallel state of the module.
+    """
+
+    _parallel_state: ParallelState
+
+    def convert(self, *args, **kwargs):
+        """Convert the module to a dynamic module."""
+        module = super().convert(*args, **kwargs)
+
+        # setup parallel state now that the module is converted
+        if module.parallel_state is None:
+            module._initialize_parallel_state()
+
+        return module
+
+    @property
+    def parallel_state(self) -> ParallelState | None:
+        """Return the parallel state of the dynamic module."""
+        return getattr(self, "_parallel_state", None)
+
+    @parallel_state.setter
+    def parallel_state(self, parallel_state: ParallelState):
+        """Set the parallel state of the dynamic module."""
+        assert isinstance(parallel_state, ParallelState), (
+            "parallel_state must be a ParallelState object!"
+        )
+        self._parallel_state = parallel_state
+
+    def _initialize_parallel_state(self):
+        """Initialize the parallel state of the dynamic module.
+
+        This method is called only if the `DynamicModule` does not have a `parallel_state` attribute
+        after `_setup` is called.
+        """
+        if torch.distributed.is_initialized():
+            warnings.warn(
+                f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
+                "Using default parallel_state which has data_parallel_group set to the default process group and "
+                "tensor_parallel_group is unspecified. "
+                "If you are using tensor parallelism for this module, you should set the parallel_state "
+                "in its `_setup` method."
+            )
+
+        self.parallel_state = ParallelState(data_parallel_group=None)

     def modelopt_post_restore(self, prefix: str = ""):
         """Post-restore to correctly configure the TensorQuantizer states.
