Skip to content

Commit 2708009

Browse files
authored
[TRTLLM-12982][perf] reuse multi-item scoring position_ids and params (#15413)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 4a88020 commit 2708009

9 files changed

Lines changed: 199 additions & 162 deletions

File tree

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 123 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
import functools
22
import math
33
import os
4+
import sys
45
import weakref
56
from dataclasses import dataclass, field
67
from typing import Any, Dict, Literal, NewType, Optional, TypeAlias, cast
78

9+
if sys.version_info[:2] >= (3, 12):
10+
from typing import override
11+
else:
12+
from typing_extensions import override
13+
814
import flashinfer
915
import torch
1016
from flashinfer.jit.core import check_cuda_arch
1117
from typing_extensions import Self
1218

1319
from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange
20+
from tensorrt_llm._utils import nvtx_range
1421
from tensorrt_llm.functional import AttentionMaskType
1522
from tensorrt_llm.logger import logger
1623
from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -157,6 +164,9 @@ class FlashInferAttentionMetadata(AttentionMetadata):
157164
_mla_kv_len_arr_buf: Optional[torch.Tensor] = field(init=False,
158165
default=None)
159166

167+
_multi_item_params: Optional[FlashInferMultiItemParams] = field(
168+
init=False, default=None)
169+
160170
def needs_plan(self, plan_params: PlanParams) -> bool:
161171
if plan_params not in self._plan_params_to_wrappers:
162172
return True
@@ -595,6 +605,7 @@ def _post_init_with_buffers(self, buffers) -> None:
595605
self._mla_ragged_planned = False
596606
self._mla_context_planned = False
597607
self._mla_decode_planned = False
608+
self._multi_item_params = None
598609

599610
def create_cuda_graph_metadata(self,
600611
max_batch_size: int,
@@ -722,6 +733,89 @@ def _plan_ragged_no_kv(
722733
**plan_kwargs,
723734
)
724735

736+
def _process_multi_item_part_lens(
737+
self,
738+
multi_item_part_lens: list[list[int]],
739+
*,
740+
device: torch.device,
741+
) -> FlashInferMultiItemParams:
742+
if self.num_generations > 0:
743+
raise ValueError(
744+
"\"multi_item_part_lens\" not supported for generation requests."
745+
)
746+
if len(multi_item_part_lens) != self.num_contexts:
747+
raise ValueError(
748+
"\"multi_item_part_lens\" needs to be provided for all requests."
749+
)
750+
751+
prefix_len_ptr = torch.tensor(
752+
[req_part_lens[0] for req_part_lens in multi_item_part_lens],
753+
pin_memory=prefer_pinned(),
754+
dtype=torch.uint32,
755+
).to(device=device, non_blocking=True)
756+
token_pos_in_items_raw_lens = [ # 'raw' lengths before padding
757+
sum(req_part_lens[1:]) + len(req_part_lens)
758+
for req_part_lens in multi_item_part_lens
759+
]
760+
token_pos_in_items_len = max(token_pos_in_items_raw_lens)
761+
max_item_len_ptr = torch.tensor(
762+
[max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens],
763+
pin_memory=prefer_pinned(),
764+
dtype=torch.uint16,
765+
).to(device=device, non_blocking=True)
766+
767+
# token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in
768+
# every request, followed by [0] (final delimiter) which is fused with padding for simplicity.
769+
range_ends = torch.tensor(
770+
[
771+
item_len + 1
772+
for req_part_lens, token_pos_in_items_raw_len in zip(
773+
multi_item_part_lens, token_pos_in_items_raw_lens)
774+
for item_len in (
775+
req_part_lens[1:] +
776+
[token_pos_in_items_len - token_pos_in_items_raw_len])
777+
],
778+
pin_memory=prefer_pinned(),
779+
dtype=torch.int32,
780+
).to(device=device, non_blocking=True)
781+
token_pos_in_items_ptr = torch_multi_arange(
782+
range_ends,
783+
output_length=(token_pos_in_items_len * len(multi_item_part_lens)),
784+
)
785+
# next, mask out the padding
786+
mask_entries = torch.arange(2, dtype=torch.uint8).to(
787+
device=device,
788+
non_blocking=True,
789+
dtype=torch.bool,
790+
).repeat(len(multi_item_part_lens)) # NB: .expand() does not work here
791+
mask_entry_repeats = torch.tensor(
792+
[
793+
repeat
794+
for token_pos_in_items_raw_len in token_pos_in_items_raw_lens
795+
for repeat in [
796+
token_pos_in_items_raw_len,
797+
token_pos_in_items_len - token_pos_in_items_raw_len,
798+
]
799+
],
800+
pin_memory=prefer_pinned(),
801+
dtype=torch.int32,
802+
).to(device=device, non_blocking=True)
803+
padding_mask = torch.repeat_interleave(
804+
input=mask_entries,
805+
repeats=mask_entry_repeats,
806+
output_size=token_pos_in_items_ptr.size(0),
807+
)
808+
token_pos_in_items_ptr.masked_fill_(padding_mask, 0)
809+
token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16,
810+
non_blocking=True)
811+
812+
return FlashInferMultiItemParams(
813+
prefix_len_ptr=prefix_len_ptr,
814+
max_item_len_ptr=max_item_len_ptr,
815+
token_pos_in_items_ptr=token_pos_in_items_ptr,
816+
token_pos_in_items_len=token_pos_in_items_len,
817+
)
818+
725819
def _clean_cached_plans(self, *, defer_plan: bool):
726820
for plan_params in list(self._plan_params_to_wrappers.keys()):
727821
# Generally, plan_params with non-trivial attention masking are relevant only the
@@ -740,11 +834,18 @@ def prepare(self) -> None:
740834
if extra_attrs is None:
741835
get_global_attrs().attention_metadata = weakref.ref(self)
742836
# start and end indices of each sequence in the ragged query
837+
assert self.seq_lens_cuda is not None
743838
torch.cumsum(self.seq_lens_cuda,
744839
dim=0,
745840
dtype=torch.int32,
746841
out=self._qo_indptr[1:self.seq_lens_cuda.size(0) + 1])
747842

843+
if self.multi_item_part_lens is not None:
844+
self._multi_item_params = self._process_multi_item_part_lens(
845+
self.multi_item_part_lens, device=self.seq_lens_cuda.device)
846+
else:
847+
self._multi_item_params = None
848+
748849
if self.kv_cache_manager is None:
749850
assert self.request_ids is not None
750851
assert self.num_generations == 0, (
@@ -761,6 +862,10 @@ def prepare(self) -> None:
761862
self._clean_cached_plans(defer_plan=False)
762863
return
763864

865+
if self._multi_item_params is not None:
866+
raise ValueError(
867+
"multi_item_part_lens with KV cache is not supported")
868+
764869
# indices of used cache blocks for each sequence
765870
assert self.request_ids is not None
766871
block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices(
@@ -1009,7 +1114,6 @@ def plan(self,
10091114
q_scaling: Optional[float] = None,
10101115
attention_window_size: Optional[int] = None,
10111116
attention_mask_data: Optional[torch.Tensor] = None,
1012-
multi_item_params: Optional[FlashInferMultiItemParams] = None,
10131117
flashinfer_backend: str = "fa2") -> PlanParams:
10141118

10151119
sm_scale = None
@@ -1027,7 +1131,7 @@ def plan(self,
10271131
if attention_window_size is not None else -1,
10281132
attention_mask_type=AttentionMaskType(attention_mask_type),
10291133
attention_mask_data=attention_mask_data,
1030-
multi_item_params=multi_item_params,
1134+
multi_item_params=self._multi_item_params,
10311135
)
10321136
return self._plan_with_params(plan_params, flashinfer_backend)
10331137

@@ -1214,6 +1318,11 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]):
12141318
def support_mla(cls) -> bool:
12151319
return True
12161320

1321+
@override
1322+
@classmethod
1323+
def support_multi_item_scoring(cls) -> bool:
1324+
return True
1325+
12171326
def __init__(
12181327
self,
12191328
layer_idx: int,
@@ -1247,90 +1356,6 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
12471356
self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
12481357
)
12491358

1250-
@staticmethod
1251-
def _process_multi_item_part_lens(
1252-
multi_item_part_lens: list[list[int]],
1253-
*,
1254-
metadata: FlashInferAttentionMetadata,
1255-
device: torch.device,
1256-
) -> FlashInferMultiItemParams:
1257-
if metadata.num_generations > 0:
1258-
raise ValueError(
1259-
"\"multi_item_part_lens\" not supported for generation requests."
1260-
)
1261-
if len(multi_item_part_lens) != metadata.num_contexts:
1262-
raise ValueError(
1263-
"\"multi_item_part_lens\" needs to be provided for all requests."
1264-
)
1265-
1266-
prefix_len_ptr = torch.tensor(
1267-
[req_part_lens[0] for req_part_lens in multi_item_part_lens],
1268-
pin_memory=prefer_pinned(),
1269-
dtype=torch.uint32,
1270-
).to(device=device, non_blocking=True)
1271-
token_pos_in_items_raw_lens = [ # 'raw' lengths before padding
1272-
sum(req_part_lens[1:]) + len(req_part_lens)
1273-
for req_part_lens in multi_item_part_lens
1274-
]
1275-
token_pos_in_items_len = max(token_pos_in_items_raw_lens)
1276-
max_item_len_ptr = torch.tensor(
1277-
[max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens],
1278-
pin_memory=prefer_pinned(),
1279-
dtype=torch.uint16,
1280-
).to(device=device, non_blocking=True)
1281-
1282-
# token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in
1283-
# every request, followed by [0] (final delimiter) which is fused with padding for simplicity.
1284-
range_ends = torch.tensor(
1285-
[
1286-
item_len + 1
1287-
for req_part_lens, token_pos_in_items_raw_len in zip(
1288-
multi_item_part_lens, token_pos_in_items_raw_lens)
1289-
for item_len in (
1290-
req_part_lens[1:] +
1291-
[token_pos_in_items_len - token_pos_in_items_raw_len])
1292-
],
1293-
pin_memory=prefer_pinned(),
1294-
dtype=torch.int32,
1295-
).to(device=device, non_blocking=True)
1296-
token_pos_in_items_ptr = torch_multi_arange(
1297-
range_ends,
1298-
output_length=(token_pos_in_items_len * len(multi_item_part_lens)),
1299-
)
1300-
# next, mask out the padding
1301-
mask_entries = torch.arange(2, dtype=torch.uint8).to(
1302-
device=device,
1303-
non_blocking=True,
1304-
dtype=torch.bool,
1305-
).repeat(len(multi_item_part_lens)) # NB: .expand() does not work here
1306-
mask_entry_repeats = torch.tensor(
1307-
[
1308-
repeat
1309-
for token_pos_in_items_raw_len in token_pos_in_items_raw_lens
1310-
for repeat in [
1311-
token_pos_in_items_raw_len,
1312-
token_pos_in_items_len - token_pos_in_items_raw_len,
1313-
]
1314-
],
1315-
pin_memory=prefer_pinned(),
1316-
dtype=torch.int32,
1317-
).to(device=device, non_blocking=True)
1318-
padding_mask = torch.repeat_interleave(
1319-
input=mask_entries,
1320-
repeats=mask_entry_repeats,
1321-
output_size=token_pos_in_items_ptr.size(0),
1322-
)
1323-
token_pos_in_items_ptr.masked_fill_(padding_mask, 0)
1324-
token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16,
1325-
non_blocking=True)
1326-
1327-
return FlashInferMultiItemParams(
1328-
prefix_len_ptr=prefix_len_ptr,
1329-
max_item_len_ptr=max_item_len_ptr,
1330-
token_pos_in_items_ptr=token_pos_in_items_ptr,
1331-
token_pos_in_items_len=token_pos_in_items_len,
1332-
)
1333-
13341359
def mla_rope_generation(
13351360
self,
13361361
fused_q: torch.Tensor,
@@ -1610,7 +1635,6 @@ def forward_impl(
16101635
output: torch.Tensor,
16111636
attention_mask_data: Optional[torch.Tensor] = None,
16121637
attention_window_size: Optional[int] = None,
1613-
multi_item_part_lens: Optional[list[list[int]]] = None,
16141638
latent_cache: Optional[torch.Tensor] = None,
16151639
attention_input_type: AttentionInputType = AttentionInputType.mixed,
16161640
) -> None:
@@ -1648,14 +1672,6 @@ def forward_impl(
16481672
# Query
16491673
q = q.view(-1, self.num_heads, self.head_dim)
16501674

1651-
multi_item_params: FlashInferMultiItemParams | None = None
1652-
if multi_item_part_lens is not None:
1653-
multi_item_params = self._process_multi_item_part_lens(
1654-
multi_item_part_lens,
1655-
metadata=metadata,
1656-
device=q.device,
1657-
)
1658-
16591675
if metadata.kv_cache_manager is None:
16601676
assert k is not None and v is not None, (
16611677
"FlashInfer without a KV cache manager requires key/value tensors."
@@ -1670,18 +1686,18 @@ def forward_impl(
16701686
assert v.shape == (q.size(0), self.num_kv_heads * self.head_dim)
16711687
k = k.view(-1, self.num_kv_heads, self.head_dim)
16721688
v = v.view(-1, self.num_kv_heads, self.head_dim)
1673-
plan_params = metadata.plan(
1674-
self.num_heads,
1675-
self.num_kv_heads,
1676-
self.head_dim,
1677-
q_dtype=q.dtype,
1678-
kv_dtype=k.dtype,
1679-
q_scaling=self.q_scaling,
1680-
attention_window_size=attention_window_size,
1681-
attention_mask_type=attention_mask_type,
1682-
attention_mask_data=attention_mask_data,
1683-
multi_item_params=multi_item_params,
1684-
)
1689+
with nvtx_range("metadata.plan"):
1690+
plan_params = metadata.plan(
1691+
self.num_heads,
1692+
self.num_kv_heads,
1693+
self.head_dim,
1694+
q_dtype=q.dtype,
1695+
kv_dtype=k.dtype,
1696+
q_scaling=self.q_scaling,
1697+
attention_window_size=attention_window_size,
1698+
attention_mask_type=attention_mask_type,
1699+
attention_mask_data=attention_mask_data,
1700+
)
16851701
wrapper = metadata.get_ragged_prefill_wrapper(plan_params)
16861702
if isinstance(wrapper,
16871703
flashinfer.BatchPrefillWithPagedKVCacheWrapper):
@@ -1700,11 +1716,6 @@ def forward_impl(
17001716
)
17011717
return
17021718

1703-
if multi_item_part_lens is not None:
1704-
raise ValueError(
1705-
"Multi-item masking support not implemented for paged KV cache."
1706-
)
1707-
17081719
# Key and Value
17091720
kv_cache = metadata.kv_cache_manager.get_buffers(
17101721
self.layer_idx, kv_layout=metadata.kv_layout)
@@ -1901,6 +1912,5 @@ def forward(self,
19011912
output=output,
19021913
latent_cache=latent_cache,
19031914
attention_input_type=forward_args.attention_input_type,
1904-
multi_item_part_lens=forward_args.multi_item_part_lens,
19051915
)
19061916
return output

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,13 @@ class AttentionMetadata:
174174
# The number of heads per kv head.
175175
num_heads_per_kv: Optional[int] = 1
176176

177+
multi_item_part_lens: Optional[list[list[int]]] = None
178+
"""Additional token layout information for multi-item scoring.
179+
180+
Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch,
181+
see `TokensPrompt` for details.
182+
"""
183+
177184
def __post_init__(self) -> None:
178185
if self.is_cross:
179186
assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata"
@@ -835,13 +842,6 @@ class AttentionForwardArgs:
835842
relative_attention_max_distance: int = 0
836843
cross_kv: Optional[torch.Tensor] = None
837844

838-
multi_item_part_lens: Optional[list[list[int]]] = None
839-
"""Additional token layout information for multi-item scoring.
840-
841-
Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch,
842-
see `TokensPrompt` for details.
843-
"""
844-
845845
latent_cache: Optional[torch.Tensor] = None
846846
q_pe: Optional[torch.Tensor] = None
847847
mrope_rotary_cos_sin: Optional[torch.Tensor] = None
@@ -984,6 +984,10 @@ def support_fused_qkv(cls) -> bool:
984984
def support_mla(cls) -> bool:
985985
return False
986986

987+
@classmethod
988+
def support_multi_item_scoring(cls) -> bool:
989+
return False
990+
987991
def create_output(self, q: torch.Tensor, **kwargs) -> List[torch.Tensor]:
988992
"""
989993
Create the output tensors for the attention operation.

0 commit comments

Comments
 (0)