Skip to content
Open
236 changes: 123 additions & 113 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import functools
import math
import os
import sys
import weakref
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, NewType, Optional, TypeAlias, cast

if sys.version_info[:2] >= (3, 12):
from typing import override
else:
from typing_extensions import override

import flashinfer
import torch
from flashinfer.jit.core import check_cuda_arch
from typing_extensions import Self

from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
Expand Down Expand Up @@ -157,6 +164,9 @@ class FlashInferAttentionMetadata(AttentionMetadata):
_mla_kv_len_arr_buf: Optional[torch.Tensor] = field(init=False,
default=None)

_multi_item_params: Optional[FlashInferMultiItemParams] = field(
init=False, default=None)

def needs_plan(self, plan_params: PlanParams) -> bool:
if plan_params not in self._plan_params_to_wrappers:
return True
Expand Down Expand Up @@ -595,6 +605,7 @@ def _post_init_with_buffers(self, buffers) -> None:
self._mla_ragged_planned = False
self._mla_context_planned = False
self._mla_decode_planned = False
self._multi_item_params = None

def create_cuda_graph_metadata(self,
max_batch_size: int,
Expand Down Expand Up @@ -722,6 +733,89 @@ def _plan_ragged_no_kv(
**plan_kwargs,
)

def _process_multi_item_part_lens(
    self,
    multi_item_part_lens: list[list[int]],
    *,
    device: torch.device,
) -> FlashInferMultiItemParams:
    """Convert per-request part lengths into FlashInfer multi-item parameters.

    Each entry of ``multi_item_part_lens`` describes one request: element 0
    is used as the prefix length and the remaining elements as item lengths.

    Args:
        multi_item_part_lens: One list of part lengths per context request.
        device: Device on which the resulting tensors are placed.

    Returns:
        A ``FlashInferMultiItemParams`` holding ``prefix_len_ptr`` (uint32,
        one entry per request), ``max_item_len_ptr`` (uint16, longest item
        per request), ``token_pos_in_items_ptr`` (uint16, flattened, with
        ``token_pos_in_items_len`` entries per request, zero padded) and
        ``token_pos_in_items_len`` (the padded per-request length).

    Raises:
        ValueError: If the batch contains generation requests, if the number
            of entries differs from ``self.num_contexts``, if no entries are
            given, or if an entry lacks a prefix length or an item length.
    """
    if self.num_generations > 0:
        raise ValueError(
            "\"multi_item_part_lens\" not supported for generation requests."
        )
    num_requests = len(multi_item_part_lens)
    if num_requests != self.num_contexts:
        raise ValueError(
            "\"multi_item_part_lens\" needs to be provided for all requests."
        )

    # Validate the layout before any tensor is created. Otherwise malformed
    # input only surfaces later as an opaque "max() arg is an empty sequence"
    # (or an IndexError), after a host-to-device copy was already queued.
    if num_requests == 0:
        raise ValueError("\"multi_item_part_lens\" must not be empty.")
    for req_idx, req_part_lens in enumerate(multi_item_part_lens):
        if len(req_part_lens) < 2:
            raise ValueError(
                f"\"multi_item_part_lens\"[{req_idx}] must contain a prefix "
                "length followed by at least one item length.")

    prefix_len_ptr = torch.tensor(
        [req_part_lens[0] for req_part_lens in multi_item_part_lens],
        pin_memory=prefer_pinned(),
        dtype=torch.uint32,
    ).to(device=device, non_blocking=True)

    # 'raw' lengths before padding: every item contributes item_len + 1
    # entries and each request ends with one extra delimiter entry.
    token_pos_in_items_raw_lens = [
        sum(req_part_lens[1:]) + len(req_part_lens)
        for req_part_lens in multi_item_part_lens
    ]
    # All requests are padded to the longest raw length in the batch.
    token_pos_in_items_len = max(token_pos_in_items_raw_lens)
    padding_lens = [
        token_pos_in_items_len - raw_len
        for raw_len in token_pos_in_items_raw_lens
    ]

    max_item_len_ptr = torch.tensor(
        [max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens],
        pin_memory=prefer_pinned(),
        dtype=torch.uint16,
    ).to(device=device, non_blocking=True)

    # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in
    # every request, followed by [0] (final delimiter) which is fused with padding for simplicity.
    range_ends = torch.tensor(
        [
            item_len + 1
            for req_part_lens, padding_len in zip(multi_item_part_lens,
                                                  padding_lens)
            for item_len in (req_part_lens[1:] + [padding_len])
        ],
        pin_memory=prefer_pinned(),
        dtype=torch.int32,
    ).to(device=device, non_blocking=True)
    token_pos_in_items_ptr = torch_multi_arange(
        range_ends,
        output_length=(token_pos_in_items_len * num_requests),
    )

    # next, mask out the padding: the fused delimiter/padding range counts
    # 0, 1, 2, ... but every padding position has to hold 0.
    mask_entries = torch.arange(2, dtype=torch.uint8).to(
        device=device,
        non_blocking=True,
        dtype=torch.bool,
    ).repeat(num_requests)  # NB: .expand() does not work here
    # Per request: `raw_len` times False (keep), then `padding_len` times True.
    mask_entry_repeats = torch.tensor(
        [
            repeat
            for raw_len, padding_len in zip(token_pos_in_items_raw_lens,
                                            padding_lens)
            for repeat in (raw_len, padding_len)
        ],
        pin_memory=prefer_pinned(),
        dtype=torch.int32,
    ).to(device=device, non_blocking=True)
    padding_mask = torch.repeat_interleave(
        input=mask_entries,
        repeats=mask_entry_repeats,
        output_size=token_pos_in_items_ptr.size(0),
    )
    token_pos_in_items_ptr.masked_fill_(padding_mask, 0)
    token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16,
                                                       non_blocking=True)

    return FlashInferMultiItemParams(
        prefix_len_ptr=prefix_len_ptr,
        max_item_len_ptr=max_item_len_ptr,
        token_pos_in_items_ptr=token_pos_in_items_ptr,
        token_pos_in_items_len=token_pos_in_items_len,
    )

def _clean_cached_plans(self, *, defer_plan: bool):
for plan_params in list(self._plan_params_to_wrappers.keys()):
# Generally, plan_params with non-trivial attention masking are relevant only the
Expand All @@ -740,11 +834,18 @@ def prepare(self) -> None:
if extra_attrs is None:
get_global_attrs().attention_metadata = weakref.ref(self)
# start and end indices of each sequence in the ragged query
assert self.seq_lens_cuda is not None
torch.cumsum(self.seq_lens_cuda,
dim=0,
dtype=torch.int32,
out=self._qo_indptr[1:self.seq_lens_cuda.size(0) + 1])

if self.multi_item_part_lens is not None:
self._multi_item_params = self._process_multi_item_part_lens(
self.multi_item_part_lens, device=self.seq_lens_cuda.device)
else:
self._multi_item_params = None

if self.kv_cache_manager is None:
assert self.request_ids is not None
assert self.num_generations == 0, (
Expand All @@ -761,6 +862,10 @@ def prepare(self) -> None:
self._clean_cached_plans(defer_plan=False)
return

if self._multi_item_params is not None:
raise ValueError(
"multi_item_part_lens with KV cache is not supported")

# indices of used cache blocks for each sequence
assert self.request_ids is not None
block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices(
Expand Down Expand Up @@ -1009,7 +1114,6 @@ def plan(self,
q_scaling: Optional[float] = None,
attention_window_size: Optional[int] = None,
attention_mask_data: Optional[torch.Tensor] = None,
multi_item_params: Optional[FlashInferMultiItemParams] = None,
flashinfer_backend: str = "fa2") -> PlanParams:

sm_scale = None
Expand All @@ -1027,7 +1131,7 @@ def plan(self,
if attention_window_size is not None else -1,
attention_mask_type=AttentionMaskType(attention_mask_type),
attention_mask_data=attention_mask_data,
multi_item_params=multi_item_params,
multi_item_params=self._multi_item_params,
)
return self._plan_with_params(plan_params, flashinfer_backend)

Expand Down Expand Up @@ -1214,6 +1318,11 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]):
def support_mla(cls) -> bool:
return True

@override
@classmethod
def support_multi_item_scoring(cls) -> bool:
    """Report multi-item scoring support; this implementation returns True."""
    return True

def __init__(
self,
layer_idx: int,
Expand Down Expand Up @@ -1247,90 +1356,6 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
)

@staticmethod
def _process_multi_item_part_lens(
    multi_item_part_lens: list[list[int]],
    *,
    metadata: FlashInferAttentionMetadata,
    device: torch.device,
) -> FlashInferMultiItemParams:
    """Convert per-request part lengths into FlashInfer multi-item parameters.

    Each entry of ``multi_item_part_lens`` describes one request: element 0
    is used as the prefix length and the remaining elements as item lengths.

    Args:
        multi_item_part_lens: One list of part lengths per context request.
        metadata: Attention metadata; only ``num_generations`` and
            ``num_contexts`` are read.
        device: Device on which the resulting tensors are placed.

    Returns:
        A ``FlashInferMultiItemParams`` holding ``prefix_len_ptr`` (uint32,
        one entry per request), ``max_item_len_ptr`` (uint16, longest item
        per request), ``token_pos_in_items_ptr`` (uint16, flattened, with
        ``token_pos_in_items_len`` entries per request, zero padded) and
        ``token_pos_in_items_len`` (the padded per-request length).

    Raises:
        ValueError: If the batch contains generation requests, if the number
            of entries differs from ``metadata.num_contexts``, if no entries
            are given, or if an entry lacks a prefix length or an item length.
    """
    if metadata.num_generations > 0:
        raise ValueError(
            "\"multi_item_part_lens\" not supported for generation requests."
        )
    num_requests = len(multi_item_part_lens)
    if num_requests != metadata.num_contexts:
        raise ValueError(
            "\"multi_item_part_lens\" needs to be provided for all requests."
        )

    # Validate the layout before any tensor is created. Otherwise malformed
    # input only surfaces later as an opaque "max() arg is an empty sequence"
    # (or an IndexError), after a host-to-device copy was already queued.
    if num_requests == 0:
        raise ValueError("\"multi_item_part_lens\" must not be empty.")
    for req_idx, req_part_lens in enumerate(multi_item_part_lens):
        if len(req_part_lens) < 2:
            raise ValueError(
                f"\"multi_item_part_lens\"[{req_idx}] must contain a prefix "
                "length followed by at least one item length.")

    prefix_len_ptr = torch.tensor(
        [req_part_lens[0] for req_part_lens in multi_item_part_lens],
        pin_memory=prefer_pinned(),
        dtype=torch.uint32,
    ).to(device=device, non_blocking=True)

    # 'raw' lengths before padding: every item contributes item_len + 1
    # entries and each request ends with one extra delimiter entry.
    token_pos_in_items_raw_lens = [
        sum(req_part_lens[1:]) + len(req_part_lens)
        for req_part_lens in multi_item_part_lens
    ]
    # All requests are padded to the longest raw length in the batch.
    token_pos_in_items_len = max(token_pos_in_items_raw_lens)
    padding_lens = [
        token_pos_in_items_len - raw_len
        for raw_len in token_pos_in_items_raw_lens
    ]

    max_item_len_ptr = torch.tensor(
        [max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens],
        pin_memory=prefer_pinned(),
        dtype=torch.uint16,
    ).to(device=device, non_blocking=True)

    # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in
    # every request, followed by [0] (final delimiter) which is fused with padding for simplicity.
    range_ends = torch.tensor(
        [
            item_len + 1
            for req_part_lens, padding_len in zip(multi_item_part_lens,
                                                  padding_lens)
            for item_len in (req_part_lens[1:] + [padding_len])
        ],
        pin_memory=prefer_pinned(),
        dtype=torch.int32,
    ).to(device=device, non_blocking=True)
    token_pos_in_items_ptr = torch_multi_arange(
        range_ends,
        output_length=(token_pos_in_items_len * num_requests),
    )

    # next, mask out the padding: the fused delimiter/padding range counts
    # 0, 1, 2, ... but every padding position has to hold 0.
    mask_entries = torch.arange(2, dtype=torch.uint8).to(
        device=device,
        non_blocking=True,
        dtype=torch.bool,
    ).repeat(num_requests)  # NB: .expand() does not work here
    # Per request: `raw_len` times False (keep), then `padding_len` times True.
    mask_entry_repeats = torch.tensor(
        [
            repeat
            for raw_len, padding_len in zip(token_pos_in_items_raw_lens,
                                            padding_lens)
            for repeat in (raw_len, padding_len)
        ],
        pin_memory=prefer_pinned(),
        dtype=torch.int32,
    ).to(device=device, non_blocking=True)
    padding_mask = torch.repeat_interleave(
        input=mask_entries,
        repeats=mask_entry_repeats,
        output_size=token_pos_in_items_ptr.size(0),
    )
    token_pos_in_items_ptr.masked_fill_(padding_mask, 0)
    token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16,
                                                       non_blocking=True)

    return FlashInferMultiItemParams(
        prefix_len_ptr=prefix_len_ptr,
        max_item_len_ptr=max_item_len_ptr,
        token_pos_in_items_ptr=token_pos_in_items_ptr,
        token_pos_in_items_len=token_pos_in_items_len,
    )

def mla_rope_generation(
self,
fused_q: torch.Tensor,
Expand Down Expand Up @@ -1610,7 +1635,6 @@ def forward_impl(
output: torch.Tensor,
attention_mask_data: Optional[torch.Tensor] = None,
attention_window_size: Optional[int] = None,
multi_item_part_lens: Optional[list[list[int]]] = None,
latent_cache: Optional[torch.Tensor] = None,
attention_input_type: AttentionInputType = AttentionInputType.mixed,
) -> None:
Expand Down Expand Up @@ -1648,14 +1672,6 @@ def forward_impl(
# Query
q = q.view(-1, self.num_heads, self.head_dim)

multi_item_params: FlashInferMultiItemParams | None = None
if multi_item_part_lens is not None:
multi_item_params = self._process_multi_item_part_lens(
multi_item_part_lens,
metadata=metadata,
device=q.device,
)

if metadata.kv_cache_manager is None:
assert k is not None and v is not None, (
"FlashInfer without a KV cache manager requires key/value tensors."
Expand All @@ -1670,18 +1686,18 @@ def forward_impl(
assert v.shape == (q.size(0), self.num_kv_heads * self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
plan_params = metadata.plan(
self.num_heads,
self.num_kv_heads,
self.head_dim,
q_dtype=q.dtype,
kv_dtype=k.dtype,
q_scaling=self.q_scaling,
attention_window_size=attention_window_size,
attention_mask_type=attention_mask_type,
attention_mask_data=attention_mask_data,
multi_item_params=multi_item_params,
)
with nvtx_range("metadata.plan"):
plan_params = metadata.plan(
self.num_heads,
self.num_kv_heads,
self.head_dim,
q_dtype=q.dtype,
kv_dtype=k.dtype,
q_scaling=self.q_scaling,
attention_window_size=attention_window_size,
attention_mask_type=attention_mask_type,
attention_mask_data=attention_mask_data,
)
wrapper = metadata.get_ragged_prefill_wrapper(plan_params)
if isinstance(wrapper,
flashinfer.BatchPrefillWithPagedKVCacheWrapper):
Expand All @@ -1700,11 +1716,6 @@ def forward_impl(
)
return

if multi_item_part_lens is not None:
raise ValueError(
"Multi-item masking support not implemented for paged KV cache."
)

# Key and Value
kv_cache = metadata.kv_cache_manager.get_buffers(
self.layer_idx, kv_layout=metadata.kv_layout)
Expand Down Expand Up @@ -1901,6 +1912,5 @@ def forward(self,
output=output,
latent_cache=latent_cache,
attention_input_type=forward_args.attention_input_type,
multi_item_part_lens=forward_args.multi_item_part_lens,
)
return output
18 changes: 11 additions & 7 deletions tensorrt_llm/_torch/attention_backend/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ class AttentionMetadata:
# The number of heads per kv head.
num_heads_per_kv: Optional[int] = 1

multi_item_part_lens: Optional[list[list[int]]] = None
"""Additional token layout information for multi-item scoring.

Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch,
see `TokensPrompt` for details.
"""

def __post_init__(self) -> None:
if self.is_cross:
assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata"
Expand Down Expand Up @@ -835,13 +842,6 @@ class AttentionForwardArgs:
relative_attention_max_distance: int = 0
cross_kv: Optional[torch.Tensor] = None

multi_item_part_lens: Optional[list[list[int]]] = None
"""Additional token layout information for multi-item scoring.

Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch,
see `TokensPrompt` for details.
"""

latent_cache: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
mrope_rotary_cos_sin: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -984,6 +984,10 @@ def support_fused_qkv(cls) -> bool:
def support_mla(cls) -> bool:
return False

@classmethod
def support_multi_item_scoring(cls) -> bool:
    """Report multi-item scoring support; this implementation returns False."""
    return False

def create_output(self, q: torch.Tensor, **kwargs) -> List[torch.Tensor]:
"""
Create the output tensors for the attention operation.
Expand Down
Loading
Loading