Skip to content

Commit 1aa232a

Browse files
authored
[TRTLLM-12807][feat] Add multiple FMHA library support to TRTLLM attention backend (#15204)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent 5c8a359 commit 1aa232a

9 files changed

Lines changed: 935 additions & 587 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from .fallback import FallbackFmha
17+
from .flashinfer_trtllm_gen import FlashInferTrtllmGenFmha
18+
from .interface import Fmha
19+
from .phased import FmhaParams, PhasedFmha
20+
from .registry import DEFAULT_FMHA_LIBS, FMHA_LIBS, FmhaCls, get_enabled_fmha_lib_classes
21+
22+
__all__ = [
23+
"DEFAULT_FMHA_LIBS",
24+
"FMHA_LIBS",
25+
"FallbackFmha",
26+
"FlashInferTrtllmGenFmha",
27+
"Fmha",
28+
"FmhaCls",
29+
"FmhaParams",
30+
"PhasedFmha",
31+
"get_enabled_fmha_lib_classes",
32+
]
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import TYPE_CHECKING, Optional
17+
18+
import torch
19+
20+
from tensorrt_llm._torch.attention_backend.interface import AttentionForwardArgs
21+
from tensorrt_llm._torch.attention_backend.sparse.skip_softmax import (
22+
SkipSoftmaxKernelParams,
23+
SkipSoftmaxParams,
24+
)
25+
from tensorrt_llm.bindings.internal import thop
26+
27+
from .interface import Fmha
28+
29+
if TYPE_CHECKING:
30+
from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
31+
32+
33+
# ``AttentionForwardArgs`` fields that this backend does not consume.
34+
# Sync test (test_attention_op_sync.py) requires every other field to map to a
35+
# kwarg name, a @property on the dataclass, or a field that some @property
36+
# transitively reads; entries here are exempt.
37+
_THOP_EXCLUDED_FIELDS: frozenset = frozenset(
38+
{
39+
"topk_indices", # DSA-only
40+
"attention_mask_data", # custom-mask code path
41+
"out_scale_sf", # promoted into ``out_scale`` in ``TrtllmAttention.forward`` for NVFP4 path
42+
}
43+
)
44+
45+
# ``thop.attention`` kwargs hard-wired to a literal at the call site (no
46+
# rich object owns them). Sync test enforces both the kwarg name and the
47+
# literal value.
48+
_THOP_LITERALS: dict = {}
49+
50+
51+
class FallbackFmha(Fmha):
52+
"""Fallback FMHA implementation using the fused TRT-LLM thop attention op."""
53+
54+
def forward(
55+
self,
56+
q: torch.Tensor,
57+
k: Optional[torch.Tensor],
58+
v: Optional[torch.Tensor],
59+
metadata: "TrtllmAttentionMetadata",
60+
forward_args: AttentionForwardArgs,
61+
) -> None:
62+
attn = self.attn
63+
sparse_params = attn.sparse_params
64+
skip_softmax_kernel_params = (
65+
sparse_params.scheduler.get_kernel_params(timestep=forward_args.timestep)
66+
if isinstance(sparse_params, SkipSoftmaxParams)
67+
else SkipSoftmaxKernelParams()
68+
)
69+
70+
# Every kwarg sources from ``attn`` / ``metadata`` / ``forward_args``
71+
# (with ``forward_args.sparse_prediction`` for sparse-attn inputs),
72+
# ``skip_softmax_kernel_params``, or a literal allowlisted in
73+
# ``_THOP_LITERALS``. ``test_attention_op_sync.py`` enforces this
74+
# statically.
75+
thop.attention(
76+
q=q,
77+
k=k,
78+
v=v,
79+
output=forward_args.output,
80+
output_sf=forward_args.output_sf,
81+
workspace_=metadata.effective_workspace,
82+
# --- Per-step batch state (TrtllmAttentionMetadata) ---
83+
sequence_length=metadata.kv_lens_cuda_runtime,
84+
host_past_key_value_lengths=metadata.kv_lens_runtime,
85+
host_total_kv_lens=metadata.host_total_kv_lens,
86+
context_lengths=metadata.prompt_lens_cuda_runtime,
87+
host_context_lengths=metadata.prompt_lens_cpu_runtime,
88+
host_request_types=metadata.host_request_types_runtime,
89+
max_context_q_len_override=metadata.max_context_q_len_override,
90+
kv_cache_block_offsets=metadata.kv_cache_block_offsets,
91+
host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
92+
host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
93+
cache_indirection=metadata.cache_indirection,
94+
block_ids_per_seq=metadata.block_ids_per_seq,
95+
tokens_per_block=metadata.tokens_per_block,
96+
max_num_requests=metadata.max_num_requests,
97+
beam_width=metadata.effective_beam_width,
98+
use_paged_context_fmha=metadata.use_paged_context_fmha,
99+
helix_position_offsets=metadata.helix_position_offsets,
100+
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
101+
is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
102+
use_spec_decoding=metadata.use_spec_decoding,
103+
is_spec_dec_tree=metadata.is_spec_dec_tree,
104+
spec_decoding_generation_lengths=metadata.spec_decoding_generation_lengths,
105+
spec_decoding_position_offsets_for_cpp=metadata.spec_decoding_position_offsets_for_cpp,
106+
spec_decoding_packed_mask=metadata.spec_decoding_packed_mask,
107+
spec_decoding_bl_tree_mask_offset=metadata.spec_decoding_bl_tree_mask_offset,
108+
spec_decoding_bl_tree_mask=metadata.spec_decoding_bl_tree_mask,
109+
spec_decoding_target_max_draft_tokens=metadata.max_total_draft_tokens,
110+
spec_bl_tree_first_sparse_mask_offset_kv=metadata.spec_bl_tree_first_sparse_mask_offset_kv,
111+
num_sparse_topk=metadata.num_sparse_topk,
112+
flash_mla_tile_scheduler_metadata=metadata.flash_mla_tile_scheduler_metadata,
113+
flash_mla_num_splits=metadata.flash_mla_num_splits,
114+
num_contexts=metadata.num_contexts,
115+
num_ctx_tokens=metadata.num_ctx_tokens,
116+
max_context_length=metadata.max_context_length,
117+
max_seq_len=metadata.max_seq_len,
118+
trtllm_gen_jit_warmup=metadata.trtllm_gen_jit_warmup,
119+
is_cross=metadata.is_cross,
120+
# --- Per-call (AttentionForwardArgs) ---
121+
out_scale=forward_args.out_scale,
122+
kv_scale_orig_quant=forward_args.kv_scale_orig_quant,
123+
kv_scale_quant_orig=forward_args.kv_scale_quant_orig,
124+
latent_cache=forward_args.latent_cache,
125+
q_pe=forward_args.q_pe,
126+
attention_sinks=forward_args.attention_sinks,
127+
mask_type=forward_args.mask_type,
128+
attention_input_type=int(forward_args.attention_input_type),
129+
attention_window_size=forward_args.attention_window_size,
130+
chunked_prefill_buffer_batch_size=forward_args.chunked_prefill_buffer_batch_size,
131+
mrope_rotary_cos_sin=forward_args.mrope_rotary_cos_sin,
132+
mrope_position_deltas=forward_args.mrope_position_deltas,
133+
softmax_stats_tensor=forward_args.softmax_stats_tensor,
134+
cu_q_seqlens=forward_args.cu_q_seqlens,
135+
cu_kv_seqlens=forward_args.cu_kv_seqlens,
136+
fmha_scheduler_counter=forward_args.fmha_scheduler_counter,
137+
mla_bmm1_scale=forward_args.mla_bmm1_scale,
138+
mla_bmm2_scale=forward_args.mla_bmm2_scale,
139+
quant_q_buffer=forward_args.quant_q_buffer,
140+
sage_attn_num_elts_per_blk_q=forward_args.sage_attn_num_elts_per_blk_q,
141+
sage_attn_num_elts_per_blk_k=forward_args.sage_attn_num_elts_per_blk_k,
142+
sage_attn_num_elts_per_blk_v=forward_args.sage_attn_num_elts_per_blk_v,
143+
sage_attn_qk_int8=forward_args.sage_attn_qk_int8,
144+
is_fused_qkv=forward_args.is_fused_qkv,
145+
update_kv_cache=forward_args.update_kv_cache,
146+
cross_kv=forward_args.cross_kv,
147+
relative_attention_bias=forward_args.relative_attention_bias,
148+
relative_attention_max_distance=forward_args.relative_attention_max_distance,
149+
# --- Module config (TrtllmAttention) ---
150+
rotary_inv_freq=attn.rotary_inv_freq,
151+
rotary_cos_sin=attn.rotary_cos_sin,
152+
predicted_tokens_per_seq=attn.predicted_tokens_per_seq,
153+
local_layer_idx=attn.local_layer_idx,
154+
num_heads=attn.num_heads,
155+
num_kv_heads=attn.num_kv_heads,
156+
head_size=attn.head_dim,
157+
quant_mode=attn.quant_mode,
158+
q_scaling=attn.q_scaling,
159+
position_embedding_type=attn.position_embedding_type,
160+
rope_dim=attn.rope_dim,
161+
rope_base=attn.rope_base,
162+
rope_scale_type=attn.rope_scale_type,
163+
rope_scale=attn.rope_scale,
164+
rope_short_m_scale=attn.rope_short_m_scale,
165+
rope_long_m_scale=attn.rope_long_m_scale,
166+
rope_max_positions=attn.rope_max_positions,
167+
rope_original_max_positions=attn.rope_original_max_positions,
168+
is_mla_enable=attn.is_mla_enable,
169+
q_lora_rank=attn.q_lora_rank,
170+
kv_lora_rank=attn.kv_lora_rank,
171+
qk_nope_head_dim=attn.qk_nope_head_dim,
172+
qk_rope_head_dim=attn.qk_rope_head_dim,
173+
v_head_dim=attn.v_head_dim,
174+
rope_append=attn.rope_append,
175+
attention_chunk_size=attn.attention_chunk_size,
176+
skip_softmax_threshold_scale_factor_prefill=skip_softmax_kernel_params.threshold_scale_factor_prefill,
177+
skip_softmax_threshold_scale_factor_decode=skip_softmax_kernel_params.threshold_scale_factor_decode,
178+
skip_softmax_stat=attn.skip_softmax_stat,
179+
# --- Sparse-specific (AttentionForwardArgs.sparse_prediction) ---
180+
sparse_kv_indices=forward_args.sparse_prediction.sparse_kv_indices,
181+
sparse_kv_offsets=forward_args.sparse_prediction.sparse_kv_offsets,
182+
sparse_attn_indices=forward_args.sparse_prediction.sparse_attn_indices,
183+
sparse_attn_offsets=forward_args.sparse_prediction.sparse_attn_offsets,
184+
sparse_attn_indices_block_size=forward_args.sparse_prediction.sparse_attn_indices_block_size,
185+
sparse_mla_topk_lens=forward_args.sparse_prediction.sparse_mla_topk_lens,
186+
compressed_kv_cache_pool_ptr=forward_args.sparse_prediction.compressed_kv_cache_pool_ptr,
187+
)

0 commit comments

Comments
 (0)