Skip to content

Commit bfefc5e

Browse files
committed
[None][fix] resolve DeepSeek V4 rebase fallout
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent 6008be7 commit bfefc5e

3 files changed

Lines changed: 51 additions & 25 deletions

File tree

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe_mxfp4.py

Lines changed: 36 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
_WEIGHT_ALIGNMENT = 128
4040
_MXFP4_DEFAULT_EXPERT_BLOCK_SIZE = 32
4141
_MXFP4_SUPPORTED_EXPERT_BLOCK_SIZE = 32
42+
_MXFP4_VALUES_PER_BYTE = 2
4243
_MXFP4_LAYOUT_ARG_NAMES = (
4344
"gate_up_blocks",
4445
"gate_up_scales",
@@ -375,12 +376,24 @@ def _register_existing_mxfp4_expert_layout_hooks(
375376
return num_hooks
376377

377378

379+
def _mxfp4_block_count(name: str, dim: int, expert_block_size: int) -> int:
380+
if dim <= 0:
381+
raise ValueError(f"MXFP4 expert {name} should be positive, got {dim}.")
382+
if dim % expert_block_size != 0:
383+
raise ValueError(
384+
f"MXFP4 expert {name} should be divisible by expert_block_size="
385+
f"{expert_block_size}, got {dim}."
386+
)
387+
return dim // expert_block_size
388+
389+
378390
def _register_mxfp4_expert_params(
379391
gm: GraphModule,
380392
gate_up_w_name: str,
381393
gate_up_b_name: str,
382394
down_w_name: str,
383395
down_b_name: str,
396+
expert_block_size: int = _MXFP4_DEFAULT_EXPERT_BLOCK_SIZE,
384397
) -> Tuple[str, str, str, str]:
385398
"""Create (if missing) the four MXFP4 params under the experts module and return their full names.
386399
@@ -404,9 +417,9 @@ def _register_mxfp4_expert_params(
404417
# Fallback: use down bias last dim
405418
H = int(dn_b.shape[1])
406419

407-
# Compute block dims (assume divisible; zero-init anyway)
408-
H_blk = max(1, H // 32)
409-
I_blk = max(1, In // 32)
420+
packed_block_width = expert_block_size // _MXFP4_VALUES_PER_BYTE
421+
H_blk = _mxfp4_block_count("hidden_size", H, expert_block_size)
422+
I_blk = _mxfp4_block_count("intermediate_size", In, expert_block_size)
410423

411424
experts_mod, experts_path, _ = get_submodule_of_param(gm, gate_up_w_name)
412425

@@ -421,9 +434,13 @@ def _register_mxfp4_expert_params(
421434
# (meta in the normal meta-device build) so we don't materialize giant CPU
422435
# buffers before load.
423436
param_device = gu_w.device
424-
gu_blocks = torch.empty((E, 2 * In, H_blk, 16), dtype=torch.uint8, device=param_device)
437+
gu_blocks = torch.empty(
438+
(E, 2 * In, H_blk, packed_block_width), dtype=torch.uint8, device=param_device
439+
)
425440
gu_scales = torch.empty((E, 2 * In, H_blk), dtype=torch.uint8, device=param_device)
426-
dn_blocks = torch.empty((E, H, I_blk, 16), dtype=torch.uint8, device=param_device)
441+
dn_blocks = torch.empty(
442+
(E, H, I_blk, packed_block_width), dtype=torch.uint8, device=param_device
443+
)
427444
dn_scales = torch.empty((E, H, I_blk), dtype=torch.uint8, device=param_device)
428445

429446
experts_mod.register_parameter(gu_blocks_name, nn.Parameter(gu_blocks, requires_grad=False))
@@ -722,8 +739,7 @@ def _apply(
722739
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
723740
)
724741
checkpoint_layout = _get_packed_mxfp4_expert_layout(qcfg)
725-
if checkpoint_layout is not None or qcfg.get("expert_block_size") is not None:
726-
_resolve_mxfp4_expert_block_size(qcfg, checkpoint_layout)
742+
expert_block_size = _resolve_mxfp4_expert_block_size(qcfg, checkpoint_layout)
727743
num_existing_hooks = 0
728744
if checkpoint_layout is not None:
729745
num_existing_hooks = _register_existing_mxfp4_expert_layout_hooks(
@@ -744,7 +760,9 @@ def _apply(
744760
ad_logger.info(f"quantize_mxfp4_moe: dispatching to backend={backend!r}")
745761

746762
if backend == "triton":
747-
gm, info = self._apply_triton(gm, cm, factory, shared_config)
763+
gm, info = self._apply_triton(
764+
gm, cm, factory, shared_config, expert_block_size=expert_block_size
765+
)
748766
elif backend == "trtllm":
749767
gm, info = self._apply_trtllm(gm, cm, factory, shared_config)
750768
else:
@@ -767,6 +785,8 @@ def _apply_triton(
767785
cm,
768786
factory,
769787
shared_config,
788+
*,
789+
expert_block_size: int = _MXFP4_DEFAULT_EXPERT_BLOCK_SIZE,
770790
) -> Tuple[GraphModule, TransformInfo]:
771791
"""Triton backend: graph rewrite to ``triton_mxfp4_moe``.
772792
@@ -820,7 +840,14 @@ def _apply_triton(
820840

821841
# Register MXFP4 params on experts
822842
gu_blocks_name, gu_scales_name, dn_blocks_name, dn_scales_name = (
823-
_register_mxfp4_expert_params(gm, gu_w_name, gu_b_name, dn_w_name, dn_b_name)
843+
_register_mxfp4_expert_params(
844+
gm,
845+
gu_w_name,
846+
gu_b_name,
847+
dn_w_name,
848+
dn_b_name,
849+
expert_block_size=expert_block_size,
850+
)
824851
)
825852

826853
# Alpha/limit (from dense call)

tests/unittest/auto_deploy/multigpu/transformations/library/test_apply_sharding_hints.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,11 @@
3535
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
3636
from tensorrt_llm._torch.auto_deploy.models.quant_checkpoint_layout import QuantizedCheckpointLayout
3737
from tensorrt_llm._torch.auto_deploy.transform.interface import SharedConfig, Stages
38-
from tensorrt_llm._torch.auto_deploy.transform.library.mxfp4_moe import (
39-
InsertMXFP4MLP,
40-
MXFP4MLPConfig,
38+
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe_mxfp4 import (
39+
QuantizeMXFP4MOE as InsertMXFP4MLP,
40+
)
41+
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe_mxfp4 import (
42+
QuantizeMXFP4MOEConfig as MXFP4MLPConfig,
4143
)
4244
from tensorrt_llm._torch.auto_deploy.transform.library.sharding import _get_dist_ops
4345
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer

tests/unittest/auto_deploy/singlegpu/custom_ops/moe/test_torch_mxfp4_moe.py

Lines changed: 10 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,13 @@
2323
build_deepseek_v4_packed_mxfp4_experts_layout,
2424
)
2525
from tensorrt_llm._torch.auto_deploy.transform.interface import Stages # noqa: E402
26-
from tensorrt_llm._torch.auto_deploy.transform.library.mxfp4_moe import ( # noqa: E402
27-
InsertMXFP4MLP,
28-
MXFP4MLPConfig,
26+
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe_mxfp4 import ( # noqa: E402
27+
QuantizeMXFP4MOE as InsertMXFP4MLP,
28+
)
29+
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe_mxfp4 import (
30+
QuantizeMXFP4MOEConfig as MXFP4MLPConfig,
31+
)
32+
from tensorrt_llm._torch.auto_deploy.transform.library.fused_moe_mxfp4 import (
2933
_mxfp4_target_names_from_node,
3034
_resolve_mxfp4_expert_block_size,
3135
)
@@ -752,18 +756,11 @@ def test_torch_mxfp4_moe_from_routing_ep_allows_cuda_graph_capture() -> None:
752756
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
753757

754758

755-
def test_mxfp4_transform_backend_selector_prefers_torch_for_checkpoint_layout() -> None:
756-
config = MXFP4MLPConfig(stage=Stages.PATTERN_MATCHER)
759+
def test_mxfp4_transform_backend_selector_respects_explicit_triton() -> None:
760+
config = MXFP4MLPConfig(stage=Stages.PATTERN_MATCHER, backend="triton")
757761
transform = InsertMXFP4MLP(config)
758762

759-
assert transform._resolve_backend({"quant_method": "mxfp4"}, None) == "triton"
760-
assert transform._resolve_backend({"expert_quant_method": "mxfp4"}, object()) == "torch"
761-
assert (
762-
InsertMXFP4MLP(
763-
MXFP4MLPConfig(stage=Stages.PATTERN_MATCHER, mxfp4_backend="triton")
764-
)._resolve_backend({"expert_quant_method": "mxfp4"}, object())
765-
== "triton"
766-
)
763+
assert transform._resolve_backend() == "triton"
767764

768765

769766
def test_mxfp4_target_names_from_node_uses_op_schema_names_for_kwargs() -> None:

0 commit comments

Comments
 (0)