Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions examples/auto_deploy/model_registry/configs/gemma4_dense.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@
# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode.
model_factory: Gemma4ForConditionalGeneration
tokenizer: google/gemma-4-31B-it
world_size: 4
attn_backend: triton
compile_backend: torch-cudagraph
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_tokens: 8192
max_batch_size: 512
max_seq_len: 8192
batch_sizes: [1, 2, 3, 4, 5, 6, 7, 8]
max_num_tokens: 16000
max_batch_size: 8
max_seq_len: 16000
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.8
free_gpu_memory_fraction: 0.6
transforms:
compile_model:
piecewise_enabled: true
piecewise_num_tokens: [1, 2, 4, 8, 16, 32, 64, 128, 256,
512, 768, 1024, 1280, 1536, 1792, 2048,
2560, 3072, 4096, 5120, 6144, 7168, 8192]
mlir_elementwise_fusion:
enabled: true
gather_logits_before_lm_head:
Expand Down
263 changes: 263 additions & 0 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,116 @@ class FlashInferAttentionMetadata(AttentionMetadata):
_mla_kv_len_arr_buf: Optional[torch.Tensor] = field(init=False,
default=None)

# One-engine speculative decoding (MTP/Eagle3). Mirror TrtllmAttentionMetadata
# so update_spec_dec_param() can allocate/fill buffers on Hopper, where TRTLLM
# MMHA cannot serve Gemma4 global layers (head_dim=512).
is_spec_decoding_enabled: bool = False
use_spec_decoding: bool = False
is_spec_dec_tree: bool = False
is_spec_dec_dynamic_tree: bool = False
max_total_draft_tokens: Optional[int] = None
spec_decoding_position_offsets: Optional[torch.Tensor] = None
spec_decoding_position_offsets_cpp: Optional[torch.Tensor] = None
position_offsets_stride: int = 0
spec_decoding_packed_mask: Optional[torch.Tensor] = None
spec_decoding_generation_lengths: Optional[torch.Tensor] = None
spec_decoding_bl_tree_mask_offset: Optional[torch.Tensor] = None
spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None
spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None
# MTP/Eagle3 draft loop updates these in place (mirrors TrtllmAttentionMetadata).
host_request_types: torch.Tensor = field(init=False)
kv_lens_cuda: torch.Tensor = field(init=False)

def is_sm_version_trtllm_gen_kernel(self, sm: int) -> bool:
from .trtllm import TrtllmAttention
return TrtllmAttention.is_sm_version_trtllm_gen_kernel(sm)

def update_spec_dec_param(self, *args, **kwargs) -> None:
from .trtllm import TrtllmAttentionMetadata
return TrtllmAttentionMetadata.update_spec_dec_param(self, *args,
**kwargs)

def update_position_offsets_for_cpp(self, query_len: int) -> None:
from .trtllm import TrtllmAttentionMetadata
return TrtllmAttentionMetadata.update_position_offsets_for_cpp(
self, query_len)

def generate_spec_decoding_generation_length(self,
runtime_draft_len: int) -> None:
from .trtllm import TrtllmAttentionMetadata
return TrtllmAttentionMetadata.generate_spec_decoding_generation_length(
self, runtime_draft_len)

def update_for_spec_dec(self) -> None:
"""Refresh paged-KV views after MTP draft-loop in-place kv_lens_cuda edits."""
n = self.num_contexts + self.num_generations
if n == 0:
return
kv_lens = self.kv_lens_cuda[:n]
self._cached_token_lens[:n].copy_(
kv_lens - self.seq_lens_kv_cuda[:n])
num_blocks = ((kv_lens + self.page_size - 1) // self.page_size)
if getattr(self, "num_blocks", None) is not None:
for i in range(n):
self.num_blocks[i] = int(num_blocks[i].item())
self.num_context_blocks = sum(self.num_blocks[:self.num_contexts])
self.num_generation_blocks = sum(
self.num_blocks[self.num_contexts:])
paged_kv_last_page_len = kv_lens - (num_blocks - 1) * self.page_size
self._paged_kv_last_page_len[:n].copy_(paged_kv_last_page_len)
if self.num_contexts == 0 and self.num_generations > 0:
paged_kv_indptr_decode = torch.cumsum(
torch.tensor([0] + self.num_blocks[self.num_contexts:],
dtype=torch.int32),
dim=0,
)
self.paged_kv_indptr_decode[:paged_kv_indptr_decode.size(
0)].copy_(paged_kv_indptr_decode, non_blocking=True)
self.paged_kv_indptr = self.paged_kv_indptr_decode[:
paged_kv_indptr_decode
.size(0)]

def _triton_physical_kv_lens(
self,
start: int,
end: int,
use_spec_dec: bool,
) -> torch.Tensor:
"""Map TRTLLM spec-dec KV lengths to physical paged-KV lengths for Triton."""
kv_lens = self.kv_lens_cuda[start:end]
if (use_spec_dec
and self.spec_decoding_generation_lengths is not None):
kv_lens = kv_lens - self.spec_decoding_generation_lengths[
start:end]
return kv_lens

def _triton_gen_paged_params(
self,
gen_start: int,
gen_end: int,
use_spec_dec: bool,
):
"""Rebuild decode page table views using physical KV lengths."""
kv_lens = self._triton_physical_kv_lens(gen_start, gen_end,
use_spec_dec)
page_size = self.page_size
num_blocks = ((kv_lens + page_size - 1) // page_size).to(torch.int32)
last_page_len = kv_lens - (num_blocks - 1) * page_size
indptr = torch.zeros(gen_end - gen_start + 1,
dtype=torch.int32,
device=kv_lens.device)
indptr[1:] = torch.cumsum(num_blocks, dim=0)
base_indptr = self.paged_kv_indptr_decode[gen_start:gen_end + 1]
indices_parts = []
for i in range(gen_end - gen_start):
block_count = int(num_blocks[i].item())
indices_parts.append(
self._paged_kv_indices[base_indptr[i]:base_indptr[i] +
block_count])
kv_indices = torch.cat(indices_parts) if indices_parts else (
self._paged_kv_indices[:0])
return kv_lens, last_page_len, indptr, kv_indices

def needs_plan(self, plan_params: PlanParams) -> bool:
if plan_params not in self._plan_params_to_wrappers:
return True
Expand Down Expand Up @@ -503,6 +613,13 @@ def _post_init_with_buffers(self, buffers) -> None:
self._cached_token_lens = torch.empty((self.max_num_requests, ),
dtype=torch.int,
device='cuda')
self.kv_lens_cuda = torch.empty((self.max_num_requests, ),
dtype=torch.int,
device='cuda')
self.host_request_types = torch.empty((self.max_num_requests, ),
dtype=torch.int,
device='cpu',
pin_memory=prefer_pinned())
self._batch_indices = torch.empty((self.max_num_tokens, ),
dtype=torch.int,
device='cuda')
Expand Down Expand Up @@ -780,6 +897,10 @@ def prepare(self) -> None:

# number of tokens needed in the kv cache for each sequence after the next pass
kv_lens = self.cached_token_lens + self.seq_lens_kv_cuda
n_seqs = self.num_contexts + self.num_generations
self.kv_lens_cuda[:n_seqs].copy_(kv_lens[:n_seqs], non_blocking=True)
self.host_request_types[:self.num_contexts].fill_(0)
self.host_request_types[self.num_contexts:n_seqs].fill_(1)

# start and end indices of each sequence in the ragged key and value
# for self attention it's the same as qo_indptr so avoid computing twice.
Expand Down Expand Up @@ -1048,6 +1169,9 @@ def _plan_with_params(self,
if not self.needs_plan(plan_params):
return plan_params

if flashinfer_backend == "triton":
return plan_params

if self.is_cuda_graph and torch.cuda.is_current_stream_capturing():
raise ValueError(
"Cannot plan() for flashinfer kernels while stream is capturing. "
Expand Down Expand Up @@ -1600,6 +1724,121 @@ def _mla_forward_paged_context(
out=output[:num_tokens].view(-1, self.num_heads,
self.kv_lora_rank))

def _forward_triton_paged(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
kv_cache: torch.Tensor,
metadata: FlashInferAttentionMetadata,
output: torch.Tensor,
attention_mask_data: Optional[torch.Tensor],
attention_window_size: Optional[int],
num_contexts: int,
num_generations: int,
num_ctx_tokens: int,
) -> None:
"""Paged attention via Triton kernels (Hopper and pre-Blackwell fallback)."""
from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_attention import (
triton_context,
triton_decode,
)

from .triton_prefill import triton_prefill_with_custom_mask

sm_scale = 1 / (math.sqrt(self.head_dim) * self.q_scaling)
sliding_window = attention_window_size
window_left = attention_window_size if attention_window_size is not None else -1
if metadata.is_cuda_graph:
raise RuntimeError(
"Gemma4 Triton attention (Hopper fallback) is incompatible with "
"decode CUDA graphs. Set cuda_graph_config: null in your serve "
"config, or use --backend _autodeploy with gemma4_dense.yaml."
)
out_view = output.view(-1, self.num_heads, self.head_dim)

if num_contexts > 0:
q_ctx = q[:num_ctx_tokens].view(-1, self.num_heads, self.head_dim)
kv_lens = metadata.kv_lens_cuda[:num_contexts]
if attention_mask_data is not None and k is not None:
k_ctx = k[:num_ctx_tokens].view(-1, self.num_kv_heads, self.head_dim)
v_ctx = v[:num_ctx_tokens].view(-1, self.num_kv_heads, self.head_dim)
triton_prefill_with_custom_mask(
q=q_ctx,
k=k_ctx,
v=v_ctx,
output=out_view[:num_ctx_tokens],
qo_indptr=metadata.qo_indptr[:num_contexts + 1],
kv_cache=kv_cache,
prefix_lens=metadata.cached_token_lens[:num_contexts].clone(),
page_table_indptr=metadata.paged_kv_indptr_prefill[:num_contexts + 1],
page_table_indices=metadata._paged_kv_indices[:metadata.num_context_blocks],
page_size=metadata.page_size,
custom_mask=attention_mask_data,
sm_scale=sm_scale,
window_left=window_left,
)
else:
triton_context(
q=q_ctx,
kv_cache=kv_cache,
qo_indptr=metadata.qo_indptr[:num_contexts + 1],
kv_indptr=metadata.paged_kv_indptr_prefill[:num_contexts + 1],
kv_indices=metadata._paged_kv_indices[:metadata.num_context_blocks],
kv_last_page_len=metadata.paged_kv_last_page_len[:num_contexts],
seq_len_with_cache=kv_lens,
sm_scale=sm_scale,
sliding_window=sliding_window,
out=out_view[:num_ctx_tokens],
)

# Generation phase. Do not route head_dim=512 global-attention layers
# through FlashInfer fa2 on SM90 (see flashinfer PR #3652).
if num_generations > 0:
num_gen_tokens = q.shape[0] - num_ctx_tokens
gen_start = num_contexts
gen_end = num_contexts + num_generations
gen_kv_indices = metadata._paged_kv_indices[
metadata.num_context_blocks:
metadata.num_context_blocks + metadata.num_generation_blocks
]
# MTP/Eagle3 verification sends multiple Q tokens per generation
# sequence; triton_decode only supports one Q token per sequence.
use_spec_dec = getattr(metadata, 'use_spec_decoding', False)
multi_token_gen = num_gen_tokens > num_generations
if use_spec_dec or multi_token_gen:
q_gen = q[num_ctx_tokens:].view(-1, self.num_heads, self.head_dim)
gen_qo_indptr = metadata.qo_indptr[gen_start:gen_end + 1].clone()
gen_qo_indptr -= gen_qo_indptr[0].item()
kv_lens, last_page_len, kv_indptr, gen_kv_indices = (
metadata._triton_gen_paged_params(gen_start, gen_end,
use_spec_dec))
triton_context(
q=q_gen,
kv_cache=kv_cache,
qo_indptr=gen_qo_indptr,
kv_indptr=kv_indptr,
kv_indices=gen_kv_indices,
kv_last_page_len=last_page_len,
seq_len_with_cache=kv_lens,
sm_scale=sm_scale,
sliding_window=sliding_window,
out=out_view[num_ctx_tokens:],
)
else:
q_dec = q[num_ctx_tokens:num_ctx_tokens + num_generations].view(
num_generations, self.num_heads, self.head_dim)
triton_decode(
q=q_dec,
kv_cache=kv_cache,
kv_indices=gen_kv_indices,
kv_indptr=metadata.paged_kv_indptr_decode[:num_generations + 1],
kv_last_page_len=metadata._paged_kv_last_page_len[gen_start:gen_end],
sm_scale=sm_scale,
sliding_window=sliding_window,
out=out_view[num_ctx_tokens:num_ctx_tokens + num_generations],
)

def forward_impl(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -1748,6 +1987,30 @@ def forward_impl(
num_contexts = metadata.num_contexts
num_generations = metadata.num_generations
num_ctx_tokens = metadata.num_ctx_tokens
use_spec_dec = getattr(metadata, 'use_spec_decoding', False)
num_gen_tokens = q.shape[0] - num_ctx_tokens
multi_token_gen = num_generations > 0 and num_gen_tokens > num_generations

# Hopper Gemma4 routes head_dim=512 layers through Triton. Sliding
# layers (head_dim=256) use fa2 for single-token decode, but MTP/Eagle3
# verification sends multiple Q tokens per generation sequence; fa2
# batch decode rejects that, so fall back to Triton multi-token context.
if (self.flashinfer_backend == "triton" or use_spec_dec
or multi_token_gen):
self._forward_triton_paged(
q=q,
k=k,
v=v,
kv_cache=kv_cache,
metadata=metadata,
output=output,
attention_mask_data=attention_mask_data,
attention_window_size=attention_window_size,
num_contexts=num_contexts,
num_generations=num_generations,
num_ctx_tokens=num_ctx_tokens,
)
return

def prefill_forward(plan_params: PlanParams, out: torch.Tensor):
wrapper = metadata.get_prefill_wrapper(plan_params)
Expand Down
Loading
Loading