Skip to content

Commit 8f71a83

Browse files
committed
[None][test] add DeepSeek V4 Flash AutoDeploy smoke
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent 8183123 commit 8f71a83

5 files changed

Lines changed: 151 additions & 12 deletions

File tree

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/mxfp4_moe.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@
4949
)
5050
_E8M0_EXPONENT_BIAS = 127
5151
_MXFP4_BLOCK_SIZE = 32
52+
_TORCH_MXFP4_ROUTED_MOE_TOKEN_CHUNK = 16
5253

5354
# Prepared (swizzled) triton_kernels tensors; typed as ``object`` so the module
5455
# imports without ``triton_kernels``.
@@ -319,18 +320,23 @@ def _run_torch_mxfp4_from_routing_slots(
319320

320321
x_for_bmm = x.unsqueeze(-1)
321322
for route_idx in range(local_expert_idx.shape[1]):
322-
expert_idx = local_expert_idx[:, route_idx]
323-
gate_up = torch.bmm(gate_up_weight.index_select(0, expert_idx), x_for_bmm).squeeze(-1)
324-
gate_up = gate_up + gate_up_bias.index_select(0, expert_idx).to(torch.float32)
325-
inter = _apply_swiglu(gate_up, alpha, limit, gate_up_order, swiglu_mode)
326-
expert_output = torch.bmm(
327-
down_weight.index_select(0, expert_idx), inter.unsqueeze(-1)
328-
).squeeze(-1)
329-
expert_output = expert_output + down_bias.index_select(0, expert_idx).to(torch.float32)
330-
route_scale = routing_weights[:, route_idx, None] * valid_route[:, route_idx, None].to(
331-
torch.float32
332-
)
333-
output = output + expert_output * route_scale
323+
for start in range(0, x.shape[0], _TORCH_MXFP4_ROUTED_MOE_TOKEN_CHUNK):
324+
end = min(start + _TORCH_MXFP4_ROUTED_MOE_TOKEN_CHUNK, x.shape[0])
325+
token_slice = slice(start, end)
326+
expert_idx = local_expert_idx[token_slice, route_idx]
327+
gate_up = torch.bmm(
328+
gate_up_weight.index_select(0, expert_idx), x_for_bmm[token_slice]
329+
).squeeze(-1)
330+
gate_up = gate_up + gate_up_bias.index_select(0, expert_idx).to(torch.float32)
331+
inter = _apply_swiglu(gate_up, alpha, limit, gate_up_order, swiglu_mode)
332+
expert_output = torch.bmm(
333+
down_weight.index_select(0, expert_idx), inter.unsqueeze(-1)
334+
).squeeze(-1)
335+
expert_output = expert_output + down_bias.index_select(0, expert_idx).to(torch.float32)
336+
route_scale = routing_weights[token_slice, route_idx, None] * valid_route[
337+
token_slice, route_idx, None
338+
].to(torch.float32)
339+
output[token_slice] = output[token_slice] + expert_output * route_scale
334340

335341
return output.reshape(*leading_shape, hidden_size).to(output_dtype)
336342

tests/integration/defs/accuracy/configs/deepseek_v4_flash_4gpu_smoke.yaml

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,22 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
transforms:
17+
apply_sharding_hints:
18+
dist_mapping:
19+
tp: 4
20+
moe_ep: 4
21+
moe_tp: 1
22+
moe_cluster: 1

tests/integration/defs/accuracy/test_llm_api_autodeploy.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from test_common.llm_data import hf_id_to_local_model_dir
2525

2626
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
27+
from tensorrt_llm.evaluate import GSM8K as GSM8KEvaluator
2728
from tensorrt_llm.llmapi import Eagle3DecodingConfig
2829
from tensorrt_llm.quantization import QuantAlgo
2930
from tensorrt_llm.sampling_params import SamplingParams
@@ -35,6 +36,7 @@
3536
'model_registry' / 'configs')
3637
_AD_MODEL_REGISTRY_DIR = Path(
3738
get_llm_root()) / 'examples' / 'auto_deploy' / 'model_registry'
39+
_ACCURACY_CONFIGS_DIR = Path(__file__).resolve().parent / "configs"
3840

3941

4042
def _load_ad_config(config_name):
@@ -1567,6 +1569,59 @@ def test_autodeploy_from_registry(self, model_name, config_overrides, tasks,
15671569
raise type(e)(f"[{task_cls.__name__}] {e}") from None
15681570

15691571

1572+
class TestDeepSeekV4Flash(LlmapiAccuracyTestHarness):
1573+
MODEL_NAME = "deepseek-ai/DeepSeek-V4-Flash"
1574+
WORLD_SIZE = 4
1575+
YAML_EXTRA = [
1576+
str(_AD_CONFIGS_DIR / "dashboard_default.yaml"),
1577+
str(_AD_CONFIGS_DIR / "world_size_4.yaml"),
1578+
str(_AD_CONFIGS_DIR / "deepseek_v4_flash.yaml"),
1579+
str(_ACCURACY_CONFIGS_DIR / "deepseek_v4_flash_4gpu_smoke.yaml"),
1580+
]
1581+
GSM8K_NUM_SAMPLES = 15
1582+
GSM8K_NUM_FEWSHOT = 0
1583+
GSM8K_MAX_INPUT_LEN = 1024
1584+
GSM8K_MAX_OUTPUT_LEN = 128
1585+
GSM8K_MIN_ACCURACY = 40.0
1586+
1587+
def get_default_sampling_params(self):
1588+
return SamplingParams(end_id=None,
1589+
pad_id=None,
1590+
max_tokens=self.GSM8K_MAX_OUTPUT_LEN,
1591+
n=1,
1592+
use_beam_search=False)
1593+
1594+
@pytest.mark.skip_less_device(4)
1595+
@pytest.mark.skip_less_device_memory(80000)
1596+
def test_gsm8k_smoke(self):
1597+
if get_device_count() < self.WORLD_SIZE:
1598+
pytest.skip(
1599+
f"DeepSeek V4 Flash smoke requires {self.WORLD_SIZE} GPUs")
1600+
1601+
with AutoDeployLLM(model=self.MODEL_NAME,
1602+
tokenizer=self.MODEL_NAME,
1603+
world_size=self.WORLD_SIZE,
1604+
yaml_extra=self.YAML_EXTRA,
1605+
max_seq_len=self.GSM8K_MAX_INPUT_LEN +
1606+
self.GSM8K_MAX_OUTPUT_LEN,
1607+
trust_remote_code=True) as llm:
1608+
task = GSM8KEvaluator(dataset_path=GSM8K.DATASET_DIR,
1609+
num_samples=self.GSM8K_NUM_SAMPLES,
1610+
random_seed=0)
1611+
for task_obj in task.task_dict.values():
1612+
task_obj.set_config("num_fewshot", self.GSM8K_NUM_FEWSHOT)
1613+
score = task.evaluate(
1614+
llm,
1615+
sampling_params=self.get_default_sampling_params(),
1616+
scores_filter="exact_match,flexible-extract",
1617+
)
1618+
1619+
assert score >= self.GSM8K_MIN_ACCURACY, (
1620+
f"DeepSeek V4 Flash GSM8K smoke accuracy {score:.2f} is below "
1621+
f"{self.GSM8K_MIN_ACCURACY:.2f} on {self.GSM8K_NUM_SAMPLES} samples"
1622+
)
1623+
1624+
15701625
# =============================================================================
15711626
# IR Sharding Path Tests
15721627
# =============================================================================

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -378,6 +378,7 @@ l0_dgx_h100:
378378
- accuracy/test_llm_api_autodeploy.py::TestQwen3_5_397B_MoE::test_bf16_small[4]
379379
- accuracy/test_llm_api_autodeploy.py::TestGemma4MoE::test_bf16
380380
- accuracy/test_llm_api_autodeploy.py::TestMiniMaxM2::test_finegrained_fp8
381+
- accuracy/test_llm_api_autodeploy.py::TestDeepSeekV4Flash::test_gsm8k_smoke
381382
# ------------- AutoDeploy Backend Stages L1 / Nightly only ---------------
382383
- condition:
383384
ranges:

tests/unittest/auto_deploy/singlegpu/custom_ops/moe/test_torch_mxfp4_moe.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -592,6 +592,61 @@ def test_torch_mxfp4_moe_from_routing_matches_deepseek_layout_reference() -> Non
592592
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
593593

594594

595+
def test_torch_mxfp4_moe_from_routing_matches_reference_across_token_chunks() -> None:
596+
num_experts = 3
597+
hidden_size = 32
598+
intermediate_size = 32
599+
alpha = 1.0
600+
limit = 0.75
601+
num_tokens = 20
602+
x = torch.linspace(-0.3, 0.35, steps=num_tokens * hidden_size, dtype=torch.float32).reshape(
603+
num_tokens, hidden_size
604+
)
605+
token_ids = torch.arange(num_tokens, dtype=torch.int64)
606+
selected_experts = torch.stack(
607+
(token_ids % num_experts, (token_ids + 1) % num_experts),
608+
dim=1,
609+
)
610+
routing_weights = torch.stack(
611+
(
612+
torch.linspace(0.15, 0.45, steps=num_tokens),
613+
torch.linspace(0.4, 0.1, steps=num_tokens),
614+
),
615+
dim=1,
616+
)
617+
packed, w1_weight, w2_weight, w3_weight = _deepseek_packed_params_from_layout(num_experts)
618+
gate_up_bias = torch.zeros((num_experts, 2 * intermediate_size), dtype=torch.float32)
619+
down_bias = torch.zeros((num_experts, hidden_size), dtype=torch.float32)
620+
621+
actual = torch.ops.auto_deploy.torch_mxfp4_moe_from_routing(
622+
x,
623+
selected_experts,
624+
routing_weights,
625+
packed.gate_up_blocks,
626+
gate_up_bias,
627+
packed.gate_up_scales,
628+
alpha,
629+
limit,
630+
packed.down_blocks,
631+
down_bias,
632+
packed.down_scales,
633+
"up_gate",
634+
"deepseek",
635+
)
636+
expected = _dense_deepseek_routing_reference(
637+
x,
638+
selected_experts,
639+
routing_weights,
640+
w1_weight,
641+
w2_weight,
642+
w3_weight,
643+
alpha=alpha,
644+
limit=limit,
645+
)
646+
647+
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
648+
649+
595650
def test_torch_mxfp4_moe_from_routing_ep_partitions_deepseek_layout_experts() -> None:
596651
num_experts = 5
597652
ep_size = 3

0 commit comments

Comments
 (0)