11import functools
22import math
33import os
4+ import sys
45import weakref
56from dataclasses import dataclass , field
67from typing import Any , Dict , Literal , NewType , Optional , TypeAlias , cast
78
9+ if sys .version_info [:2 ] >= (3 , 12 ):
10+ from typing import override
11+ else :
12+ from typing_extensions import override
13+
814import flashinfer
915import torch
1016from flashinfer .jit .core import check_cuda_arch
1117from typing_extensions import Self
1218
1319from tensorrt_llm ._torch .pyexecutor .sampling_utils import torch_multi_arange
20+ from tensorrt_llm ._utils import nvtx_range
1421from tensorrt_llm .functional import AttentionMaskType
1522from tensorrt_llm .logger import logger
1623from tensorrt_llm .models .modeling_utils import QuantConfig
@@ -157,6 +164,9 @@ class FlashInferAttentionMetadata(AttentionMetadata):
157164 _mla_kv_len_arr_buf : Optional [torch .Tensor ] = field (init = False ,
158165 default = None )
159166
167+ _multi_item_params : Optional [FlashInferMultiItemParams ] = field (
168+ init = False , default = None )
169+
160170 def needs_plan (self , plan_params : PlanParams ) -> bool :
161171 if plan_params not in self ._plan_params_to_wrappers :
162172 return True
@@ -595,6 +605,7 @@ def _post_init_with_buffers(self, buffers) -> None:
595605 self ._mla_ragged_planned = False
596606 self ._mla_context_planned = False
597607 self ._mla_decode_planned = False
608+ self ._multi_item_params = None
598609
599610 def create_cuda_graph_metadata (self ,
600611 max_batch_size : int ,
@@ -722,6 +733,89 @@ def _plan_ragged_no_kv(
722733 ** plan_kwargs ,
723734 )
724735
736+ def _process_multi_item_part_lens (
737+ self ,
738+ multi_item_part_lens : list [list [int ]],
739+ * ,
740+ device : torch .device ,
741+ ) -> FlashInferMultiItemParams :
742+ if self .num_generations > 0 :
743+ raise ValueError (
744+ "\" multi_item_part_lens\" not supported for generation requests."
745+ )
746+ if len (multi_item_part_lens ) != self .num_contexts :
747+ raise ValueError (
748+ "\" multi_item_part_lens\" needs to be provided for all requests."
749+ )
750+
751+ prefix_len_ptr = torch .tensor (
752+ [req_part_lens [0 ] for req_part_lens in multi_item_part_lens ],
753+ pin_memory = prefer_pinned (),
754+ dtype = torch .uint32 ,
755+ ).to (device = device , non_blocking = True )
756+ token_pos_in_items_raw_lens = [ # 'raw' lengths before padding
757+ sum (req_part_lens [1 :]) + len (req_part_lens )
758+ for req_part_lens in multi_item_part_lens
759+ ]
760+ token_pos_in_items_len = max (token_pos_in_items_raw_lens )
761+ max_item_len_ptr = torch .tensor (
762+ [max (req_part_lens [1 :]) for req_part_lens in multi_item_part_lens ],
763+ pin_memory = prefer_pinned (),
764+ dtype = torch .uint16 ,
765+ ).to (device = device , non_blocking = True )
766+
767+ # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in
768+ # every request, followed by [0] (final delimiter) which is fused with padding for simplicity.
769+ range_ends = torch .tensor (
770+ [
771+ item_len + 1
772+ for req_part_lens , token_pos_in_items_raw_len in zip (
773+ multi_item_part_lens , token_pos_in_items_raw_lens )
774+ for item_len in (
775+ req_part_lens [1 :] +
776+ [token_pos_in_items_len - token_pos_in_items_raw_len ])
777+ ],
778+ pin_memory = prefer_pinned (),
779+ dtype = torch .int32 ,
780+ ).to (device = device , non_blocking = True )
781+ token_pos_in_items_ptr = torch_multi_arange (
782+ range_ends ,
783+ output_length = (token_pos_in_items_len * len (multi_item_part_lens )),
784+ )
785+ # next, mask out the padding
786+ mask_entries = torch .arange (2 , dtype = torch .uint8 ).to (
787+ device = device ,
788+ non_blocking = True ,
789+ dtype = torch .bool ,
790+ ).repeat (len (multi_item_part_lens )) # NB: .expand() does not work here
791+ mask_entry_repeats = torch .tensor (
792+ [
793+ repeat
794+ for token_pos_in_items_raw_len in token_pos_in_items_raw_lens
795+ for repeat in [
796+ token_pos_in_items_raw_len ,
797+ token_pos_in_items_len - token_pos_in_items_raw_len ,
798+ ]
799+ ],
800+ pin_memory = prefer_pinned (),
801+ dtype = torch .int32 ,
802+ ).to (device = device , non_blocking = True )
803+ padding_mask = torch .repeat_interleave (
804+ input = mask_entries ,
805+ repeats = mask_entry_repeats ,
806+ output_size = token_pos_in_items_ptr .size (0 ),
807+ )
808+ token_pos_in_items_ptr .masked_fill_ (padding_mask , 0 )
809+ token_pos_in_items_ptr = token_pos_in_items_ptr .to (dtype = torch .uint16 ,
810+ non_blocking = True )
811+
812+ return FlashInferMultiItemParams (
813+ prefix_len_ptr = prefix_len_ptr ,
814+ max_item_len_ptr = max_item_len_ptr ,
815+ token_pos_in_items_ptr = token_pos_in_items_ptr ,
816+ token_pos_in_items_len = token_pos_in_items_len ,
817+ )
818+
725819 def _clean_cached_plans (self , * , defer_plan : bool ):
726820 for plan_params in list (self ._plan_params_to_wrappers .keys ()):
727821 # Generally, plan_params with non-trivial attention masking are relevant only the
@@ -740,11 +834,18 @@ def prepare(self) -> None:
740834 if extra_attrs is None :
741835 get_global_attrs ().attention_metadata = weakref .ref (self )
742836 # start and end indices of each sequence in the ragged query
837+ assert self .seq_lens_cuda is not None
743838 torch .cumsum (self .seq_lens_cuda ,
744839 dim = 0 ,
745840 dtype = torch .int32 ,
746841 out = self ._qo_indptr [1 :self .seq_lens_cuda .size (0 ) + 1 ])
747842
843+ if self .multi_item_part_lens is not None :
844+ self ._multi_item_params = self ._process_multi_item_part_lens (
845+ self .multi_item_part_lens , device = self .seq_lens_cuda .device )
846+ else :
847+ self ._multi_item_params = None
848+
748849 if self .kv_cache_manager is None :
749850 assert self .request_ids is not None
750851 assert self .num_generations == 0 , (
@@ -761,6 +862,10 @@ def prepare(self) -> None:
761862 self ._clean_cached_plans (defer_plan = False )
762863 return
763864
865+ if self ._multi_item_params is not None :
866+ raise ValueError (
867+ "multi_item_part_lens with KV cache is not supported" )
868+
764869 # indices of used cache blocks for each sequence
765870 assert self .request_ids is not None
766871 block_ids_per_seq = self .kv_cache_manager .get_batch_cache_indices (
@@ -1009,7 +1114,6 @@ def plan(self,
10091114 q_scaling : Optional [float ] = None ,
10101115 attention_window_size : Optional [int ] = None ,
10111116 attention_mask_data : Optional [torch .Tensor ] = None ,
1012- multi_item_params : Optional [FlashInferMultiItemParams ] = None ,
10131117 flashinfer_backend : str = "fa2" ) -> PlanParams :
10141118
10151119 sm_scale = None
@@ -1027,7 +1131,7 @@ def plan(self,
10271131 if attention_window_size is not None else - 1 ,
10281132 attention_mask_type = AttentionMaskType (attention_mask_type ),
10291133 attention_mask_data = attention_mask_data ,
1030- multi_item_params = multi_item_params ,
1134+ multi_item_params = self . _multi_item_params ,
10311135 )
10321136 return self ._plan_with_params (plan_params , flashinfer_backend )
10331137
@@ -1214,6 +1318,11 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]):
12141318 def support_mla (cls ) -> bool :
12151319 return True
12161320
1321+ @override
1322+ @classmethod
1323+ def support_multi_item_scoring (cls ) -> bool :
1324+ return True
1325+
12171326 def __init__ (
12181327 self ,
12191328 layer_idx : int ,
@@ -1247,90 +1356,6 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
12471356 self .has_fp8_kv_cache = self .quant_config .layer_quant_mode .has_fp8_kv_cache (
12481357 )
12491358
1250- @staticmethod
1251- def _process_multi_item_part_lens (
1252- multi_item_part_lens : list [list [int ]],
1253- * ,
1254- metadata : FlashInferAttentionMetadata ,
1255- device : torch .device ,
1256- ) -> FlashInferMultiItemParams :
1257- if metadata .num_generations > 0 :
1258- raise ValueError (
1259- "\" multi_item_part_lens\" not supported for generation requests."
1260- )
1261- if len (multi_item_part_lens ) != metadata .num_contexts :
1262- raise ValueError (
1263- "\" multi_item_part_lens\" needs to be provided for all requests."
1264- )
1265-
1266- prefix_len_ptr = torch .tensor (
1267- [req_part_lens [0 ] for req_part_lens in multi_item_part_lens ],
1268- pin_memory = prefer_pinned (),
1269- dtype = torch .uint32 ,
1270- ).to (device = device , non_blocking = True )
1271- token_pos_in_items_raw_lens = [ # 'raw' lengths before padding
1272- sum (req_part_lens [1 :]) + len (req_part_lens )
1273- for req_part_lens in multi_item_part_lens
1274- ]
1275- token_pos_in_items_len = max (token_pos_in_items_raw_lens )
1276- max_item_len_ptr = torch .tensor (
1277- [max (req_part_lens [1 :]) for req_part_lens in multi_item_part_lens ],
1278- pin_memory = prefer_pinned (),
1279- dtype = torch .uint16 ,
1280- ).to (device = device , non_blocking = True )
1281-
1282- # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in
1283- # every request, followed by [0] (final delimiter) which is fused with padding for simplicity.
1284- range_ends = torch .tensor (
1285- [
1286- item_len + 1
1287- for req_part_lens , token_pos_in_items_raw_len in zip (
1288- multi_item_part_lens , token_pos_in_items_raw_lens )
1289- for item_len in (
1290- req_part_lens [1 :] +
1291- [token_pos_in_items_len - token_pos_in_items_raw_len ])
1292- ],
1293- pin_memory = prefer_pinned (),
1294- dtype = torch .int32 ,
1295- ).to (device = device , non_blocking = True )
1296- token_pos_in_items_ptr = torch_multi_arange (
1297- range_ends ,
1298- output_length = (token_pos_in_items_len * len (multi_item_part_lens )),
1299- )
1300- # next, mask out the padding
1301- mask_entries = torch .arange (2 , dtype = torch .uint8 ).to (
1302- device = device ,
1303- non_blocking = True ,
1304- dtype = torch .bool ,
1305- ).repeat (len (multi_item_part_lens )) # NB: .expand() does not work here
1306- mask_entry_repeats = torch .tensor (
1307- [
1308- repeat
1309- for token_pos_in_items_raw_len in token_pos_in_items_raw_lens
1310- for repeat in [
1311- token_pos_in_items_raw_len ,
1312- token_pos_in_items_len - token_pos_in_items_raw_len ,
1313- ]
1314- ],
1315- pin_memory = prefer_pinned (),
1316- dtype = torch .int32 ,
1317- ).to (device = device , non_blocking = True )
1318- padding_mask = torch .repeat_interleave (
1319- input = mask_entries ,
1320- repeats = mask_entry_repeats ,
1321- output_size = token_pos_in_items_ptr .size (0 ),
1322- )
1323- token_pos_in_items_ptr .masked_fill_ (padding_mask , 0 )
1324- token_pos_in_items_ptr = token_pos_in_items_ptr .to (dtype = torch .uint16 ,
1325- non_blocking = True )
1326-
1327- return FlashInferMultiItemParams (
1328- prefix_len_ptr = prefix_len_ptr ,
1329- max_item_len_ptr = max_item_len_ptr ,
1330- token_pos_in_items_ptr = token_pos_in_items_ptr ,
1331- token_pos_in_items_len = token_pos_in_items_len ,
1332- )
1333-
13341359 def mla_rope_generation (
13351360 self ,
13361361 fused_q : torch .Tensor ,
@@ -1610,7 +1635,6 @@ def forward_impl(
16101635 output : torch .Tensor ,
16111636 attention_mask_data : Optional [torch .Tensor ] = None ,
16121637 attention_window_size : Optional [int ] = None ,
1613- multi_item_part_lens : Optional [list [list [int ]]] = None ,
16141638 latent_cache : Optional [torch .Tensor ] = None ,
16151639 attention_input_type : AttentionInputType = AttentionInputType .mixed ,
16161640 ) -> None :
@@ -1648,14 +1672,6 @@ def forward_impl(
16481672 # Query
16491673 q = q .view (- 1 , self .num_heads , self .head_dim )
16501674
1651- multi_item_params : FlashInferMultiItemParams | None = None
1652- if multi_item_part_lens is not None :
1653- multi_item_params = self ._process_multi_item_part_lens (
1654- multi_item_part_lens ,
1655- metadata = metadata ,
1656- device = q .device ,
1657- )
1658-
16591675 if metadata .kv_cache_manager is None :
16601676 assert k is not None and v is not None , (
16611677 "FlashInfer without a KV cache manager requires key/value tensors."
@@ -1670,18 +1686,18 @@ def forward_impl(
16701686 assert v .shape == (q .size (0 ), self .num_kv_heads * self .head_dim )
16711687 k = k .view (- 1 , self .num_kv_heads , self .head_dim )
16721688 v = v .view (- 1 , self .num_kv_heads , self .head_dim )
1673- plan_params = metadata .plan (
1674- self . num_heads ,
1675- self .num_kv_heads ,
1676- self .head_dim ,
1677- q_dtype = q . dtype ,
1678- kv_dtype = k .dtype ,
1679- q_scaling = self . q_scaling ,
1680- attention_window_size = attention_window_size ,
1681- attention_mask_type = attention_mask_type ,
1682- attention_mask_data = attention_mask_data ,
1683- multi_item_params = multi_item_params ,
1684- )
1689+ with nvtx_range ( " metadata.plan" ):
1690+ plan_params = metadata . plan (
1691+ self .num_heads ,
1692+ self .num_kv_heads ,
1693+ self . head_dim ,
1694+ q_dtype = q .dtype ,
1695+ kv_dtype = k . dtype ,
1696+ q_scaling = self . q_scaling ,
1697+ attention_window_size = attention_window_size ,
1698+ attention_mask_type = attention_mask_type ,
1699+ attention_mask_data = attention_mask_data ,
1700+ )
16851701 wrapper = metadata .get_ragged_prefill_wrapper (plan_params )
16861702 if isinstance (wrapper ,
16871703 flashinfer .BatchPrefillWithPagedKVCacheWrapper ):
@@ -1700,11 +1716,6 @@ def forward_impl(
17001716 )
17011717 return
17021718
1703- if multi_item_part_lens is not None :
1704- raise ValueError (
1705- "Multi-item masking support not implemented for paged KV cache."
1706- )
1707-
17081719 # Key and Value
17091720 kv_cache = metadata .kv_cache_manager .get_buffers (
17101721 self .layer_idx , kv_layout = metadata .kv_layout )
@@ -1901,6 +1912,5 @@ def forward(self,
19011912 output = output ,
19021913 latent_cache = latent_cache ,
19031914 attention_input_type = forward_args .attention_input_type ,
1904- multi_item_part_lens = forward_args .multi_item_part_lens ,
19051915 )
19061916 return output
0 commit comments