Skip to content

Commit b72ee4f

Browse files
authored
[https://nvbugs/5973536][fix] Route DSA attention through MLA custom op for torch.compile compatibility (#12186)
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
1 parent fe9e1a3 commit b72ee4f

5 files changed

Lines changed: 89 additions & 11 deletions

File tree

tensorrt_llm/_torch/modules/attention.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -937,11 +937,17 @@ def mla_custom_op_inplace(
937937
latent_cache_gen: Optional[torch.Tensor],
938938
) -> None:
939939
metadata, mla_layer = extract_extra_attrs(layer_idx, "mla")
940-
mla_layer.forward_impl(position_ids,
941-
hidden_states,
942-
metadata,
943-
output=output,
944-
latent_cache_gen=latent_cache_gen)
940+
if mla_layer.is_dsa:
941+
mla_layer.forward_impl_with_dsa(position_ids,
942+
hidden_states,
943+
metadata,
944+
output=output)
945+
else:
946+
mla_layer.forward_impl(position_ids,
947+
hidden_states,
948+
metadata,
949+
output=output,
950+
latent_cache_gen=latent_cache_gen)
945951

946952

947953
def fp8_block_scaling_bmm_out(
@@ -2597,16 +2603,15 @@ def forward(
25972603

25982604
attn_output = self.create_output(hidden_states,
25992605
attn_metadata.num_contexts)
2600-
if self.is_dsa:
2606+
if self.register_to_config:
2607+
torch.ops.trtllm.mla_custom_op_inplace(
2608+
hidden_states, position_ids, self.layer_idx_str, attn_output,
2609+
None if self.is_dsa else latent_cache_gen)
2610+
elif self.is_dsa:
26012611
self.forward_impl_with_dsa(position_ids,
26022612
hidden_states,
26032613
attn_metadata,
26042614
output=attn_output)
2605-
elif self.register_to_config:
2606-
torch.ops.trtllm.mla_custom_op_inplace(hidden_states, position_ids,
2607-
self.layer_idx_str,
2608-
attn_output,
2609-
latent_cache_gen)
26102615
else:
26112616
self.forward_impl(position_ids,
26122617
hidden_states,

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3084,6 +3084,73 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
30843084
task = GSM8K(self.MODEL_NAME)
30853085
task.evaluate(llm)
30863086

3087+
@pytest.mark.skip_less_mpi_world_size(8)
3088+
@skip_pre_blackwell
3089+
@pytest.mark.parametrize(
3090+
"tp_size,pp_size,ep_size,mtp_nextn,attention_dp,max_batch_size,moe_backend,fp8kv,chunked_prefill",
3091+
[
3092+
(8, 1, 8, 0, True, 24, "CUTLASS", False, False),
3093+
(8, 1, 8, 3, False, 16, "TRTLLM", True, True),
3094+
],
3095+
ids=["baseline", "mtp3_fp8kv_chunked"])
3096+
def test_nvfp4_multi_gpus_piecewise_cuda_graph(self, tp_size, pp_size,
3097+
ep_size, mtp_nextn,
3098+
attention_dp, max_batch_size,
3099+
moe_backend, fp8kv,
3100+
chunked_prefill):
3101+
sm_version = get_sm_version()
3102+
if moe_backend == "TRTLLM" and sm_version in (120, 121):
3103+
pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
3104+
3105+
moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
3106+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
3107+
if fp8kv:
3108+
kv_cache_config.dtype = "fp8"
3109+
kv_cache_config.enable_block_reuse = True
3110+
3111+
cuda_graph_config = CudaGraphConfig(
3112+
enable_padding=True,
3113+
max_batch_size=max_batch_size,
3114+
)
3115+
torch_compile_config = TorchCompileConfig(
3116+
enable_piecewise_cuda_graph=True,
3117+
capture_num_tokens=[2048, 8192],
3118+
max_num_streams=3,
3119+
)
3120+
pytorch_config = dict(
3121+
disable_overlap_scheduler=False,
3122+
cuda_graph_config=cuda_graph_config,
3123+
moe_config=moe_config,
3124+
torch_compile_config=torch_compile_config,
3125+
)
3126+
3127+
mtp_config = None
3128+
if mtp_nextn > 0:
3129+
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
3130+
3131+
llm_kwargs = dict(
3132+
max_batch_size=max_batch_size,
3133+
tensor_parallel_size=tp_size,
3134+
pipeline_parallel_size=pp_size,
3135+
moe_expert_parallel_size=ep_size,
3136+
kv_cache_config=kv_cache_config,
3137+
enable_attention_dp=attention_dp,
3138+
speculative_config=mtp_config,
3139+
)
3140+
if chunked_prefill:
3141+
llm_kwargs.update(
3142+
enable_chunked_prefill=True,
3143+
max_num_tokens=8192,
3144+
)
3145+
3146+
with LLM(f"{llm_models_root()}/DeepSeek-V3.2-Exp-FP4-v2",
3147+
**pytorch_config, **llm_kwargs) as llm:
3148+
3149+
task = MMLU(self.MODEL_NAME)
3150+
task.evaluate(llm)
3151+
task = GSM8K(self.MODEL_NAME)
3152+
task.evaluate(llm)
3153+
30873154
@pytest.mark.skip_less_mpi_world_size(8)
30883155
@skip_pre_blackwell
30893156
@pytest.mark.parametrize(

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -151,6 +151,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baselin
151151
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
152152
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
153153
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency_qsplit]
154+
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewise_cuda_graph[baseline]
155+
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewise_cuda_graph[mtp3_fp8kv_chunked]
154156
accuracy/test_llm_api_pytorch.py::TestQwen2_7BInstruct::test_auto_dtype
155157
accuracy/test_llm_api_pytorch.py::TestQwen3_4B::test_eagle3
156158
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[eagle3_one_model=True-enable_chunked_prefill=False-enable_max_concurrency=False-enable_draft_len_schedule=True]

tests/integration/test_lists/qa/llm_function_core_sanity.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[disable
3030
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
3131
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
3232
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
33+
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewise_cuda_graph[baseline]
34+
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewise_cuda_graph[mtp3_fp8kv_chunked]
3335
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP]
3436
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP]
3537
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,8 @@ l0_dgx_b200:
129129
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_mtp1] TIMEOUT (60)
130130
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline] TIMEOUT (60)
131131
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] TIMEOUT (60)
132+
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewise_cuda_graph[baseline] TIMEOUT (60)
133+
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewise_cuda_graph[mtp3_fp8kv_chunked] TIMEOUT (60)
132134
- accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] TIMEOUT (60)
133135
- accuracy/test_llm_api_pytorch.py::TestKimiK25::test_nvfp4[tp8] TIMEOUT (60)
134136
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus_mtp TIMEOUT (60)

0 commit comments

Comments
 (0)