Skip to content

Commit f39a42a

Browse files
committed
bench-moe enhancement: prune TEP/TTP forced-comm candidates at generation time; add DENSEGEMM+EP validation; remove spurious .cpu() in MXFP4 quantize_utils
Signed-off-by: guqiqi <29116997+guqiqi@users.noreply.github.com>
1 parent 2cc162b commit f39a42a

2 files changed

Lines changed: 35 additions & 5 deletions

File tree

tests/microbenchmarks/bench_moe/search.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from tensorrt_llm.models.modeling_utils import QuantAlgo
2828

2929
from .backend import MoeBackendType, get_backend_class
30-
from .mapping import _resolve_mapping_layout
30+
from .mapping import _PARALLEL_MODE_LAYOUTS, _resolve_mapping_layout
3131
from .specs import _ALL_BACKENDS, _FORCED_COMM_ENV_VALUES, ConfigSpec, ModelSpec, SearchSpec
3232

3333
_FUSED_COMM_BACKENDS = frozenset({"MEGAMOE_DEEPGEMM"})
@@ -65,6 +65,23 @@ def _comm_axis_for_backend(backend: Any, comm_methods: Tuple[Any, ...]) -> Tuple
6565
return comm_methods
6666

6767

68+
def _comm_axis_for_parallel_mode(pmode: str, comm_methods: Tuple[Any, ...]) -> Tuple[Any, ...]:
    """Restrict the comm-method axis for parallel modes that lack attention DP.

    The mode name is upper-cased and looked up in ``_PARALLEL_MODE_LAYOUTS``.
    When the mode has a layout entry whose ``enable_attention_dp`` value is
    falsy (expected to be TEP and TTP -- confirm against the layout table),
    the axis collapses to ``("AUTO",)``. Forced (non-AUTO) comm methods are
    understood to require attention DP (see ``is_candidate_valid``), so
    emitting them for such modes would only yield candidates that get
    pruned later; trimming the axis here avoids generating them at all.

    Args:
        pmode: Parallel-mode name; compared case-insensitively.
        comm_methods: Comm methods requested for the search.

    Returns:
        ``("AUTO",)`` for a known mode without attention DP; otherwise
        ``comm_methods`` unchanged. Modes missing from the layout table
        (e.g. CUSTOM) are passed through as-is.
    """
    mode_key = str(pmode).upper()
    known_layout = _PARALLEL_MODE_LAYOUTS.get(mode_key)
    # Only a *known* layout with attention DP disabled narrows the axis;
    # unknown modes and DP-enabled modes keep the caller's methods.
    if known_layout is not None and not known_layout["enable_attention_dp"]:
        return ("AUTO",)
    return comm_methods
83+
84+
6885
def expand_search(
6986
base_config: ConfigSpec,
7087
search: SearchSpec,
@@ -88,7 +105,13 @@ def expand_search(
88105
for backend, pmode, cgraph, combine in itertools.product(
89106
backends, parallel_modes, cuda_graph_options, combine_options
90107
):
91-
for comm in _comm_axis_for_backend(backend, comm_methods):
108+
effective_comm = _comm_axis_for_backend(backend, comm_methods)
109+
# For non-fused backends apply parallel-mode comm constraint at
110+
# generation time so TEP/TTP always get comm=AUTO instead of
111+
# generating forced-comm candidates that are immediately pruned.
112+
if effective_comm != ("NONE",):
113+
effective_comm = _comm_axis_for_parallel_mode(pmode, effective_comm)
114+
for comm in effective_comm:
92115
candidate = replace(
93116
base_config,
94117
backend=str(backend).upper(),
@@ -121,6 +144,13 @@ def is_candidate_valid(
121144
except ValueError as exc:
122145
return False, str(exc)
123146

147+
# DenseGEMM only supports TP; any EP configuration (TEP, DEP, custom ep>1) is unsupported.
148+
if config.backend.upper() == "DENSEGEMM" and moe_ep > 1:
149+
return False, (
150+
f"DENSEGEMM does not support EP (ep_size={moe_ep}); "
151+
"use TEP/DEP only with other backends"
152+
)
153+
124154
# Forced communication on non-DP / MoE-TP paths.
125155
forced = config.comm_method.upper()
126156
if forced not in ("AUTO", "NONE"):

tests/unittest/_torch/modules/moe/quantize_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,21 +1695,21 @@ def create_weights(self, **quant_kwargs) -> Dict[str, torch.Tensor]:
16951695
w1_weight, None, scaling_vector_size, True
16961696
)
16971697
w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
1698-
w1_sf_block.cpu().view(intermediate_size, -1)
1698+
w1_sf_block.view(intermediate_size, -1)
16991699
)
17001700

17011701
w2_weight_mxfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize(
17021702
w2_weight, None, scaling_vector_size, True
17031703
)
17041704
w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
1705-
w2_sf_block.cpu().view(hidden_size_out, -1)
1705+
w2_sf_block.view(hidden_size_out, -1)
17061706
)
17071707

17081708
w3_weight_mxfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize(
17091709
w3_weight, None, scaling_vector_size, True
17101710
)
17111711
w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
1712-
w3_sf_block.cpu().view(intermediate_size, -1)
1712+
w3_sf_block.view(intermediate_size, -1)
17131713
)
17141714

17151715
weights[f"{expert_id}.w1.weight"] = w1_weight_mxfp4

0 commit comments

Comments
 (0)