Skip to content

Commit da2f5dc

Browse files
minor
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 74184b8 commit da2f5dc

3 files changed

Lines changed: 7 additions & 2 deletions

File tree

examples/megatron_bridge/prune_minitron.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def main(args: argparse.Namespace):
240240
"seq_length": args.seq_length,
241241
},
242242
init_model_parallel=True,
243+
moe_grouped_gemm=False,
243244
)
244245
print_rank_0(f"\nPruning model (showing PP rank0): {unwrapped_model}")
245246
print_rank_0(

examples/pruning/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf
6464
"pipeline_dtype": torch.bfloat16,
6565
"seq_length": 4096,
6666
},
67+
moe_grouped_gemm=False,
6768
)
6869

6970
# Set up the forward loop to run on 1024 train samples

modelopt/torch/utils/plugins/mbridge.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def load_mbridge_model_from_hf(
5959
trust_remote_code: bool = False,
6060
provider_overrides: dict[str, Any] | None = None,
6161
init_model_parallel: bool = True,
62+
moe_grouped_gemm: bool = True,
6263
) -> tuple[
6364
AutoBridge,
6465
GPTModelProvider | MambaModelProvider,
@@ -73,6 +74,8 @@ def load_mbridge_model_from_hf(
7374
trust_remote_code: Whether to trust remote code.
7475
provider_overrides: Overrides for the provider.
7576
init_model_parallel: Whether to initialize model parallel.
77+
moe_grouped_gemm: Whether to use grouped GEMM for MoE.
78+
Pruning does not support grouped GEMM yet, so pass ``False`` when loading a model for pruning.
7679
7780
Returns:
7881
A tuple of (bridge, provider, model, unwrapped_model, tokenizer).
@@ -94,11 +97,11 @@ def load_mbridge_model_from_hf(
9497

9598
# disable moe_grouped_gemm in default TE spec until it's supported
9699
if isinstance(provider, MambaModelProvider):
97-
provider.mamba_stack_spec = get_te_mamba_stack_spec(moe_grouped_gemm=False)
100+
provider.mamba_stack_spec = get_te_mamba_stack_spec(moe_grouped_gemm=moe_grouped_gemm)
98101
else:
99102
provider.transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
100103
num_experts=provider.num_moe_experts,
101-
moe_grouped_gemm=False,
104+
moe_grouped_gemm=moe_grouped_gemm,
102105
qk_layernorm=provider.qk_layernorm,
103106
)
104107
provider.finalize()

0 commit comments

Comments
 (0)