|
31 | 31 | from torch.nn.parameter import Parameter |
32 | 32 |
|
33 | 33 | from modelopt.torch.utils import get_unwrapped_name, is_channels_last, unwrap_model |
34 | | -from modelopt.torch.utils.distributed import ParallelState |
35 | 34 | from modelopt.torch.utils.network import bind_forward_method |
36 | 35 |
|
37 | 36 | from .config import ModeloptBaseRule, RulesDict |
@@ -359,14 +358,10 @@ class DynamicModule(nn.Module): |
359 | 358 | should ensure that ``hparams`` are only exposed in the outermost class and that other |
360 | 359 | ``hparams`` are handled internally, including ``hparams`` of child modules that would |
361 | 360 | usually be exposed on their own (e.g., block module implementations containing DynamicLinear). |
362 | | -
|
363 | | - In addition, the class provides a ``parallel_state`` attribute that can be used to access |
364 | | - the parallel state of the module. |
365 | 361 | """ |
366 | 362 |
|
367 | 363 | # this is needed to store the special attributes for dynamic modules |
368 | 364 | _dm_attribute_manager: _DMAttributeManager |
369 | | - _parallel_state: ParallelState |
370 | 365 |
|
371 | 366 | def __init__(self, *args, **kwargs): |
372 | 367 | """Initializing a dynamic module is not allowed!""" |
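
Side note on the ``hparams`` guidance in the docstring above — a minimal sketch of the intended pattern. The `Hparam` import path and the `_register_hparam` signature are assumptions about the surrounding modelopt API; only `DynamicModule` and `_setup` appear in this diff:

```python
# Sketch only: the Hparam import path and _register_hparam signature are
# assumptions, not shown in this diff.
from modelopt.torch.opt.dynamic import DynamicModule
from modelopt.torch.opt.hparam import Hparam


class DynamicBlock(DynamicModule):
    """Outermost block that owns every searchable hparam."""

    def _setup(self):
        # Expose a single hparam on the outermost module only ...
        self._register_hparam("hidden_dim", Hparam(choices=[256, 512, 1024]))
        # ... and handle child hparams internally, e.g. by tying the
        # out_features hparam of a contained DynamicLinear to hidden_dim
        # instead of exposing it as an independent searchable hparam.
```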
@@ -657,10 +652,6 @@ def bind_forward_method_if_needed(self): |
657 | 652 | # setup new hparams and dynamic attributes |
658 | 653 | module._setup(**setup_kwargs) |
659 | 654 |
|
660 | | - # setup parallel state now that the module is converted |
661 | | - if module.parallel_state is None: |
662 | | - module._initialize_parallel_state() |
663 | | - |
664 | 655 | return module |
665 | 656 |
|
666 | 657 | def _setup(self, **setup_kwargs: Any): |
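
Since `convert` no longer falls back to a default parallel state, modules that need one must construct it themselves (e.g. in `_setup`). A sketch reconstructing the deleted fallback; the `ParallelState(data_parallel_group=None)` call and the warning condition come from the removed lines below, while the standalone-helper packaging is an assumption:

```python
# Reconstruction of the deleted fallback behavior, based on the removed
# _initialize_parallel_state lines below; the helper wrapper is an assumption.
import warnings

import torch.distributed as dist

from modelopt.torch.utils.distributed import ParallelState


def default_parallel_state(module_type: type) -> ParallelState:
    # Data-parallel over the default process group (None); the tensor-parallel
    # group stays unspecified, so tensor-parallel modules should build their
    # own ParallelState in _setup instead of relying on this default.
    if dist.is_initialized():
        warnings.warn(f"No parallel_state set for {module_type}; using the default.")
    return ParallelState(data_parallel_group=None)
```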
@@ -867,36 +858,6 @@ def original_cls(self) -> type[nn.Module]: |
867 | 858 | """ |
868 | 859 | return self._get_dm_attribute_manager().og_cls |
869 | 860 |
|
870 | | - @property |
871 | | - def parallel_state(self) -> ParallelState | None: |
872 | | - """Return the parallel state of the dynamic module.""" |
873 | | - return getattr(self, "_parallel_state", None) |
874 | | - |
875 | | - @parallel_state.setter |
876 | | - def parallel_state(self, parallel_state: ParallelState): |
877 | | - """Set the parallel state of the dynamic module.""" |
878 | | - assert isinstance(parallel_state, ParallelState), ( |
879 | | - "parallel_state must be a ParallelState object!" |
880 | | - ) |
881 | | - self._parallel_state = parallel_state |
882 | | - |
883 | | - def _initialize_parallel_state(self): |
884 | | - """Initialize the parallel state of the dynamic module. |
885 | | -
|
886 | | - This method is called only if the `DynamicModule` does not have a `parallel_state` attribute |
887 | | - after `_setup` is called. |
888 | | - """ |
889 | | - if torch.distributed.is_initialized(): |
890 | | - warnings.warn( |
891 | | - f"Distributed training is initialized but no parallel_state is set for {type(self)}. " |
892 | | - "Using a default parallel_state with data_parallel_group set to the default process group " |
893 | | - "and tensor_parallel_group left unspecified. " |
894 | | - "If you are using tensor parallelism for this module, you should set the parallel_state " |
895 | | - "in its `_setup` method." |
896 | | - ) |
897 | | - |
898 | | - self.parallel_state = ParallelState(data_parallel_group=None) |
899 | | - |
900 | 861 | def get_original_cls_by_level(self, level: int = -1) -> type[nn.Module]: |
901 | 862 | """Return the original class of the dynamic module. |
902 | 863 |
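
For the surviving context above, a brief usage sketch of `original_cls` and `get_original_cls_by_level`. It assumes `module` is an already-converted dynamic module; the `nn.Linear` example is illustrative only:

```python
# Usage sketch; assumes `module` was already converted to a DynamicModule
# subclass (e.g. a dynamic wrapper around nn.Linear).
import torch.nn as nn


def inspect_original_cls(module: nn.Module) -> None:
    print(module.original_cls)                 # pre-conversion class, e.g. nn.Linear
    print(module.get_original_cls_by_level())  # original class at the default level (-1)
```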
|
|