Commit e3e0d61

minor
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 4fee933

3 files changed: 16 additions & 15 deletions

examples/megatron_bridge/README.md

Lines changed: 2 additions & 3 deletions
@@ -43,9 +43,8 @@ Example usage for manually pruning to a specific architecture using following de
 ```bash
 torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
     --hf_model_name_or_path Qwen/Qwen3-8B \
-    --prune_export_config '{"hidden_size": 3072, "ffn_hidden_size": 9216}' \
-    --hparams_to_skip num_attention_heads \
-    --output_hf_path /tmp/Qwen3-8B-Pruned-6B
+    --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \
+    --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual
 ```

 To see the full usage for advanced configurations, run:
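
For reference, a minimal runnable sketch of how a JSON-valued flag such as `--prune_export_config` can be turned into a dict of pruning targets; the script's actual parsing and validation sit outside this hunk and may differ.

```python
# Illustrative sketch only -- prune_minitron.py's real handling of
# --prune_export_config is not part of this diff and may differ.
import json

raw = '{"hidden_size": 3584, "ffn_hidden_size": 9216}'
export_config: dict[str, int] = json.loads(raw)  # hparam name -> pruned value
assert all(isinstance(v, int) for v in export_config.values())
print(export_config)  # {'hidden_size': 3584, 'ffn_hidden_size': 9216}
```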

examples/megatron_bridge/prune_minitron.py

Lines changed: 6 additions & 8 deletions
@@ -91,7 +91,7 @@ def get_args() -> argparse.Namespace:
 
     # Pruning parameters
     parser.add_argument(
-        "--prune_intermediate_checkpoint",
+        "--prune_intermediate_ckpt",
         type=str,
         default=None,
         help=(
@@ -169,18 +169,16 @@ def get_args() -> argparse.Namespace:
     args = parser.parse_args()
 
     # Post-process arguments
-    if args.prune_intermediate_checkpoint is None:
+    if args.prune_intermediate_ckpt is None:
         if args.output_megatron_path:
-            args.prune_intermediate_checkpoint = (
+            args.prune_intermediate_ckpt = (
                 f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
             )
         elif args.output_hf_path:
-            args.prune_intermediate_checkpoint = (
-                f"{args.output_hf_path}/modelopt_pruning_scores.pth"
-            )
+            args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores.pth"
         print_rank_0(
             "No checkpoint provided to cache intermediate pruning scores. "
-            f"Setting to: {args.prune_intermediate_checkpoint}"
+            f"Setting to: {args.prune_intermediate_ckpt}"
         )
 
     if args.prune_export_config:
@@ -247,7 +245,7 @@ def main(args: argparse.Namespace):
 
     pruning_config = {
         "forward_loop": forward_loop,
-        "checkpoint": args.prune_intermediate_checkpoint,
+        "checkpoint": args.prune_intermediate_ckpt,
     }
     if args.prune_target_params is not None:
         # Restrict search space to a smaller set of candidates
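
For clarity, the fallback that the renamed `--prune_intermediate_ckpt` flag goes through can be read as a standalone helper. This is only a sketch restating the argument post-processing above; the helper name is hypothetical and does not exist in `prune_minitron.py`.

```python
# Sketch restating the post-processing above; the helper name is hypothetical.
def resolve_prune_intermediate_ckpt(
    prune_intermediate_ckpt: str | None,
    output_megatron_path: str | None,
    output_hf_path: str | None,
) -> str | None:
    """Fall back to a scores file next to whichever output path is given."""
    if prune_intermediate_ckpt is not None:
        return prune_intermediate_ckpt
    if output_megatron_path:
        return f"{output_megatron_path}/modelopt_pruning_scores.pth"
    if output_hf_path:
        return f"{output_hf_path}/modelopt_pruning_scores.pth"
    return None


print(resolve_prune_intermediate_ckpt(None, None, "/tmp/Qwen3-8B-Pruned-6B-manual"))
# -> /tmp/Qwen3-8B-Pruned-6B-manual/modelopt_pruning_scores.pth
```

The resulting `.pth` path is what the `checkpoint` entry of `pruning_config` receives, caching intermediate pruning scores alongside the exported model.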

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 8 additions & 4 deletions
@@ -17,8 +17,10 @@
 
 import contextlib
 import warnings
+from typing import Any
 
 import torch
+import torch.nn as nn
 
 from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
 from modelopt.torch.utils.distributed import ParallelState
@@ -44,9 +46,11 @@ class QuantModule(DynamicModule):
 
     _parallel_state: ParallelState
 
-    def convert(self, *args, **kwargs):
+    @classmethod
+    @torch.no_grad()
+    def convert(cls, module: nn.Module, **setup_kwargs: Any) -> "QuantModule":
         """Convert the module to a dynamic module."""
-        module = super().convert(*args, **kwargs)
+        module = super().convert(module, **setup_kwargs)
 
         # setup parallel state now that the module is converted
         if module.parallel_state is None:
@@ -56,7 +60,7 @@ def convert(self, *args, **kwargs):
 
     @property
     def parallel_state(self) -> ParallelState | None:
-        """Return the parallel state of the dynamic module."""
+        """Return the parallel state of the quant module."""
         return getattr(self, "_parallel_state", None)
 
     @parallel_state.setter
@@ -70,7 +74,7 @@ def parallel_state(self, parallel_state: ParallelState):
     def _initialize_parallel_state(self):
         """Initialize the parallel state of the dynamic module.
 
-        This method is called only if the `DynamicModule` does not have a `parallel_state` attribute
+        This method is called only if the `QuantModule` does not have a `parallel_state` attribute
         after `_setup` is called.
         """
         if torch.distributed.is_initialized():
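
For background, the new signature follows the classmethod-style `convert(module)` pattern of the `DynamicModule` parent, which the diff's `super().convert(module, **setup_kwargs)` call relies on. Below is a toy, self-contained sketch of that pattern; the class and attribute names are illustrative, not modelopt's actual implementation.

```python
# Toy sketch of the classmethod convert(module) pattern -- not modelopt's real
# DynamicModule; all names here are illustrative.
import torch
import torch.nn as nn


class ToyDynamicLinear(nn.Linear):
    @classmethod
    @torch.no_grad()
    def convert(cls, module: nn.Module) -> "ToyDynamicLinear":
        """Re-class an existing nn.Linear in place and run post-conversion setup."""
        assert isinstance(module, nn.Linear)
        module.__class__ = cls  # swap the class of the live instance
        module._setup()  # e.g. attach quantizers / parallel state here
        return module

    def _setup(self) -> None:
        self.extra_state_initialized = True


layer = ToyDynamicLinear.convert(nn.Linear(8, 4))
print(type(layer).__name__, layer.extra_state_initialized)  # ToyDynamicLinear True
```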
