From 11414d659f707d91943b68dc9519b2f59bb56b5e Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Tue, 16 Jun 2026 10:04:52 +0000 Subject: [PATCH 01/10] perf: reuse multi-item scoring position_ids and params Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../_torch/attention_backend/flashinfer.py | 229 +++++++++--------- .../_torch/attention_backend/interface.py | 9 +- .../attention_backend/star_flashinfer.py | 12 +- .../_torch/attention_backend/trtllm.py | 22 +- .../_torch/attention_backend/vanilla.py | 13 +- tensorrt_llm/_torch/modules/attention.py | 24 -- .../_torch/pyexecutor/model_engine.py | 10 +- tensorrt_llm/llmapi/llm.py | 34 ++- 8 files changed, 185 insertions(+), 168 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index 7ced64dec19a..574568200380 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -11,6 +11,7 @@ from typing_extensions import Self from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange +from tensorrt_llm._utils import nvtx_range from tensorrt_llm.functional import AttentionMaskType from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig @@ -157,6 +158,9 @@ class FlashInferAttentionMetadata(AttentionMetadata): _mla_kv_len_arr_buf: Optional[torch.Tensor] = field(init=False, default=None) + _multi_item_params: Optional[FlashInferMultiItemParams] = field( + init=False, default=None) + def needs_plan(self, plan_params: PlanParams) -> bool: if plan_params not in self._plan_params_to_wrappers: return True @@ -595,6 +599,7 @@ def _post_init_with_buffers(self, buffers) -> None: self._mla_ragged_planned = False self._mla_context_planned = False self._mla_decode_planned = False + self._multi_item_params = None def create_cuda_graph_metadata(self, max_batch_size: int, @@ -722,6 +727,89 @@ def _plan_ragged_no_kv( **plan_kwargs, ) + def _process_multi_item_part_lens( + self, + multi_item_part_lens: list[list[int]], + *, + device: torch.device, + ) -> FlashInferMultiItemParams: + if self.num_generations > 0: + raise ValueError( + "\"multi_item_part_lens\" not supported for generation requests." + ) + if len(multi_item_part_lens) != self.num_contexts: + raise ValueError( + "\"multi_item_part_lens\" needs to be provided for all requests." + ) + + prefix_len_ptr = torch.tensor( + [req_part_lens[0] for req_part_lens in multi_item_part_lens], + pin_memory=prefer_pinned(), + dtype=torch.uint32, + ).to(device=device, non_blocking=True) + token_pos_in_items_raw_lens = [ # 'raw' lengths before padding + sum(req_part_lens[1:]) + len(req_part_lens) + for req_part_lens in multi_item_part_lens + ] + token_pos_in_items_len = max(token_pos_in_items_raw_lens) + max_item_len_ptr = torch.tensor( + [max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens], + pin_memory=prefer_pinned(), + dtype=torch.uint16, + ).to(device=device, non_blocking=True) + + # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in + # every request, followed by [0] (final delimiter) which is fused with padding for simplicity. + range_ends = torch.tensor( + [ + item_len + 1 + for req_part_lens, token_pos_in_items_raw_len in zip( + multi_item_part_lens, token_pos_in_items_raw_lens) + for item_len in ( + req_part_lens[1:] + + [token_pos_in_items_len - token_pos_in_items_raw_len]) + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=device, non_blocking=True) + token_pos_in_items_ptr = torch_multi_arange( + range_ends, + output_length=(token_pos_in_items_len * len(multi_item_part_lens)), + ) + # next, mask out the padding + mask_entries = torch.arange(2, dtype=torch.uint8).to( + device=device, + non_blocking=True, + dtype=torch.bool, + ).repeat(len(multi_item_part_lens)) # NB: .expand() does not work here + mask_entry_repeats = torch.tensor( + [ + repeat + for token_pos_in_items_raw_len in token_pos_in_items_raw_lens + for repeat in [ + token_pos_in_items_raw_len, + token_pos_in_items_len - token_pos_in_items_raw_len, + ] + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=device, non_blocking=True) + padding_mask = torch.repeat_interleave( + input=mask_entries, + repeats=mask_entry_repeats, + output_size=token_pos_in_items_ptr.size(0), + ) + token_pos_in_items_ptr.masked_fill_(padding_mask, 0) + token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16, + non_blocking=True) + + return FlashInferMultiItemParams( + prefix_len_ptr=prefix_len_ptr, + max_item_len_ptr=max_item_len_ptr, + token_pos_in_items_ptr=token_pos_in_items_ptr, + token_pos_in_items_len=token_pos_in_items_len, + ) + def _clean_cached_plans(self, *, defer_plan: bool): for plan_params in list(self._plan_params_to_wrappers.keys()): # Generally, plan_params with non-trivial attention masking are relevant only the @@ -734,8 +822,10 @@ def _clean_cached_plans(self, *, defer_plan: bool): else: del self._plan_params_to_wrappers[plan_params] - def prepare(self) -> None: - super().prepare() + def prepare(self, + *, + multi_item_part_lens: list[list[int]] | None = None) -> None: + super().prepare(multi_item_part_lens=multi_item_part_lens) extra_attrs = get_model_extra_attrs() if extra_attrs is None: get_global_attrs().attention_metadata = weakref.ref(self) @@ -758,9 +848,18 @@ def prepare(self) -> None: n = self.num_seqs self._cached_token_lens[:n].zero_() self.num_ctx_cached_tokens = 0 + if multi_item_part_lens is not None: + self._multi_item_params = self._process_multi_item_part_lens( + multi_item_part_lens, device=self.seq_lens_cuda.device) + else: + self._multi_item_params = None self._clean_cached_plans(defer_plan=False) return + if multi_item_part_lens is not None: + raise ValueError( + "multi_item_part_lens with KV cache is not supported") + # indices of used cache blocks for each sequence assert self.request_ids is not None block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices( @@ -1009,7 +1108,6 @@ def plan(self, q_scaling: Optional[float] = None, attention_window_size: Optional[int] = None, attention_mask_data: Optional[torch.Tensor] = None, - multi_item_params: Optional[FlashInferMultiItemParams] = None, flashinfer_backend: str = "fa2") -> PlanParams: sm_scale = None @@ -1027,7 +1125,7 @@ def plan(self, if attention_window_size is not None else -1, attention_mask_type=AttentionMaskType(attention_mask_type), attention_mask_data=attention_mask_data, - multi_item_params=multi_item_params, + multi_item_params=self._multi_item_params, ) return self._plan_with_params(plan_params, flashinfer_backend) @@ -1247,90 +1345,6 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]): self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache( ) - @staticmethod - def _process_multi_item_part_lens( - multi_item_part_lens: list[list[int]], - *, - metadata: FlashInferAttentionMetadata, - device: torch.device, - ) -> FlashInferMultiItemParams: - if metadata.num_generations > 0: - raise ValueError( - "\"multi_item_part_lens\" not supported for generation requests." - ) - if len(multi_item_part_lens) != metadata.num_contexts: - raise ValueError( - "\"multi_item_part_lens\" needs to be provided for all requests." - ) - - prefix_len_ptr = torch.tensor( - [req_part_lens[0] for req_part_lens in multi_item_part_lens], - pin_memory=prefer_pinned(), - dtype=torch.uint32, - ).to(device=device, non_blocking=True) - token_pos_in_items_raw_lens = [ # 'raw' lengths before padding - sum(req_part_lens[1:]) + len(req_part_lens) - for req_part_lens in multi_item_part_lens - ] - token_pos_in_items_len = max(token_pos_in_items_raw_lens) - max_item_len_ptr = torch.tensor( - [max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens], - pin_memory=prefer_pinned(), - dtype=torch.uint16, - ).to(device=device, non_blocking=True) - - # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in - # every request, followed by [0] (final delimiter) which is fused with padding for simplicity. - range_ends = torch.tensor( - [ - item_len + 1 - for req_part_lens, token_pos_in_items_raw_len in zip( - multi_item_part_lens, token_pos_in_items_raw_lens) - for item_len in ( - req_part_lens[1:] + - [token_pos_in_items_len - token_pos_in_items_raw_len]) - ], - pin_memory=prefer_pinned(), - dtype=torch.int32, - ).to(device=device, non_blocking=True) - token_pos_in_items_ptr = torch_multi_arange( - range_ends, - output_length=(token_pos_in_items_len * len(multi_item_part_lens)), - ) - # next, mask out the padding - mask_entries = torch.arange(2, dtype=torch.uint8).to( - device=device, - non_blocking=True, - dtype=torch.bool, - ).repeat(len(multi_item_part_lens)) # NB: .expand() does not work here - mask_entry_repeats = torch.tensor( - [ - repeat - for token_pos_in_items_raw_len in token_pos_in_items_raw_lens - for repeat in [ - token_pos_in_items_raw_len, - token_pos_in_items_len - token_pos_in_items_raw_len, - ] - ], - pin_memory=prefer_pinned(), - dtype=torch.int32, - ).to(device=device, non_blocking=True) - padding_mask = torch.repeat_interleave( - input=mask_entries, - repeats=mask_entry_repeats, - output_size=token_pos_in_items_ptr.size(0), - ) - token_pos_in_items_ptr.masked_fill_(padding_mask, 0) - token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16, - non_blocking=True) - - return FlashInferMultiItemParams( - prefix_len_ptr=prefix_len_ptr, - max_item_len_ptr=max_item_len_ptr, - token_pos_in_items_ptr=token_pos_in_items_ptr, - token_pos_in_items_len=token_pos_in_items_len, - ) - def mla_rope_generation( self, fused_q: torch.Tensor, @@ -1610,7 +1624,6 @@ def forward_impl( output: torch.Tensor, attention_mask_data: Optional[torch.Tensor] = None, attention_window_size: Optional[int] = None, - multi_item_part_lens: Optional[list[list[int]]] = None, latent_cache: Optional[torch.Tensor] = None, attention_input_type: AttentionInputType = AttentionInputType.mixed, ) -> None: @@ -1648,14 +1661,6 @@ def forward_impl( # Query q = q.view(-1, self.num_heads, self.head_dim) - multi_item_params: FlashInferMultiItemParams | None = None - if multi_item_part_lens is not None: - multi_item_params = self._process_multi_item_part_lens( - multi_item_part_lens, - metadata=metadata, - device=q.device, - ) - if metadata.kv_cache_manager is None: assert k is not None and v is not None, ( "FlashInfer without a KV cache manager requires key/value tensors." @@ -1670,18 +1675,18 @@ def forward_impl( assert v.shape == (q.size(0), self.num_kv_heads * self.head_dim) k = k.view(-1, self.num_kv_heads, self.head_dim) v = v.view(-1, self.num_kv_heads, self.head_dim) - plan_params = metadata.plan( - self.num_heads, - self.num_kv_heads, - self.head_dim, - q_dtype=q.dtype, - kv_dtype=k.dtype, - q_scaling=self.q_scaling, - attention_window_size=attention_window_size, - attention_mask_type=attention_mask_type, - attention_mask_data=attention_mask_data, - multi_item_params=multi_item_params, - ) + with nvtx_range("metadata.plan"): + plan_params = metadata.plan( + self.num_heads, + self.num_kv_heads, + self.head_dim, + q_dtype=q.dtype, + kv_dtype=k.dtype, + q_scaling=self.q_scaling, + attention_window_size=attention_window_size, + attention_mask_type=attention_mask_type, + attention_mask_data=attention_mask_data, + ) wrapper = metadata.get_ragged_prefill_wrapper(plan_params) if isinstance(wrapper, flashinfer.BatchPrefillWithPagedKVCacheWrapper): @@ -1700,11 +1705,6 @@ def forward_impl( ) return - if multi_item_part_lens is not None: - raise ValueError( - "Multi-item masking support not implemented for paged KV cache." - ) - # Key and Value kv_cache = metadata.kv_cache_manager.get_buffers( self.layer_idx, kv_layout=metadata.kv_layout) @@ -1901,6 +1901,5 @@ def forward(self, output=output, latent_cache=latent_cache, attention_input_type=forward_args.attention_input_type, - multi_item_part_lens=forward_args.multi_item_part_lens, ) return output diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index c1630bbf5041..d80f75ab2b2d 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -308,7 +308,7 @@ def num_ctx_tokens(self) -> int: def num_tokens(self) -> int: return self._num_tokens - def prepare(self): + def prepare(self, *, multi_item_part_lens: list[list[int]] | None = None): """ Hook to be called before the forward step of the model. """ @@ -835,13 +835,6 @@ class AttentionForwardArgs: relative_attention_max_distance: int = 0 cross_kv: Optional[torch.Tensor] = None - multi_item_part_lens: Optional[list[list[int]]] = None - """Additional token layout information for multi-item scoring. - - Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch, - see `TokensPrompt` for details. - """ - latent_cache: Optional[torch.Tensor] = None q_pe: Optional[torch.Tensor] = None mrope_rotary_cos_sin: Optional[torch.Tensor] = None diff --git a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py index 8a396ab09fbc..23fec0a898b8 100644 --- a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py @@ -87,7 +87,13 @@ def num_generations(self) -> int: """ return self.seq_lens.shape[0] - self.num_contexts - self.num_queries - def prepare(self) -> None: + def prepare(self, + *, + multi_item_part_lens: list[list[int]] | None = None) -> None: + if multi_item_part_lens is not None: + raise ValueError( + "Star Attention does not support multi-item scoring") + context_lens = self.context_lens query_lens = self.query_lens # indices of used cache blocks for each sequence @@ -322,10 +328,6 @@ def forward(self, ) assert not metadata.is_cross, "Star Attention does not support cross attention yet." - if forward_args.multi_item_part_lens is not None: - raise ValueError( - "Star Attention does not support multi-item scoring") - q = q.view(-1, self.num_heads, self.head_dim) k = k.view(-1, self.num_kv_heads, self.head_dim) v = v.view(-1, self.num_kv_heads, self.head_dim) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 7ada30a7c708..7d839a3e2909 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -489,8 +489,13 @@ def update_helix_param( self.helix_is_inactive_rank[:batch_size].copy_( self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True) - def prepare(self) -> None: - super().prepare() + def prepare(self, + *, + multi_item_part_lens: list[list[int]] | None = None) -> None: + if multi_item_part_lens is not None: + raise ValueError( + "TRT-LLM Attention does not support multi-item scoring") + super().prepare(multi_item_part_lens=multi_item_part_lens) extra_attrs = get_model_extra_attrs() # If model extra attrs is set, attention_metadata is setup in executor. if extra_attrs is None: @@ -589,8 +594,15 @@ def prepare(self) -> None: self.host_request_types_runtime = self.host_request_types[:self. num_seqs] - def prepare_encoder_only(self) -> None: + def prepare_encoder_only( + self, + *, + multi_item_part_lens: list[list[int]] | None = None) -> None: """Fast path for encoder-only forward (eager + CUDA graph capture).""" + if multi_item_part_lens is not None: + raise ValueError( + "TRT-LLM Attention does not support multi-item scoring") + extra_attrs = get_model_extra_attrs() if extra_attrs is None: get_global_attrs().attention_metadata = weakref.ref(self) @@ -1750,10 +1762,6 @@ def forward( # Cross-attention uses the THOP path; the trtllm-gen backend API does # not carry encoder K/V tensors yet. - if forward_args.multi_item_part_lens is not None: - raise ValueError( - "TRT-LLM Attention does not support multi-item scoring") - # SM90 forces ``use_paged_context_fmha`` on for correctness # (https://nvbugs/5624818). if get_sm_version() == 90: diff --git a/tensorrt_llm/_torch/attention_backend/vanilla.py b/tensorrt_llm/_torch/attention_backend/vanilla.py index 511dee84b743..7155d248b45e 100644 --- a/tensorrt_llm/_torch/attention_backend/vanilla.py +++ b/tensorrt_llm/_torch/attention_backend/vanilla.py @@ -60,8 +60,13 @@ def generate_sliding_window_mask(batch_size: int, target_length: int, class VanillaAttentionMetadata(AttentionMetadata): - def prepare(self) -> None: - super().prepare() + def prepare(self, + *, + multi_item_part_lens: list[list[int]] | None = None) -> None: + if multi_item_part_lens is not None: + raise ValueError( + "Vanilla Attention does not support multi-item scoring") + super().prepare(multi_item_part_lens=multi_item_part_lens) # indices of used cache blocks for each sequence assert self.request_ids is not None self.block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices( @@ -488,10 +493,6 @@ def forward(self, **kwargs) -> torch.Tensor: forward_args = merge_attention_forward_args(forward_args, kwargs) - if forward_args.multi_item_part_lens is not None: - raise ValueError( - "Vanilla Attention does not support multi-item scoring") - if metadata.kv_cache_manager is None: # NOTE: WAR for no kv cache attn e.g. BERT, # try to separate the kv cache estimation path from no kv cache attn. diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 28d5cff5c3d8..cc0694728829 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -698,7 +698,6 @@ def _attn_impl( relative_attention_bias: Optional[torch.Tensor] = None, relative_attention_max_distance: int = 0, has_lora: bool = False, - multi_item_part_lens: Optional[list[list[int]]] = None, ): num_tokens = attn_metadata.num_tokens @@ -738,7 +737,6 @@ def _attn_impl( relative_attention_bias=relative_attention_bias, relative_attention_max_distance= relative_attention_max_distance, - multi_item_part_lens=multi_item_part_lens, )) if isinstance(attn_output, tuple): attn_output = attn_output[0] @@ -787,7 +785,6 @@ def _attn_impl( attention_sinks=attention_sinks, relative_attention_bias=relative_attention_bias, relative_attention_max_distance=relative_attention_max_distance, - multi_item_part_lens=multi_item_part_lens, )) if isinstance(attn_output, tuple): assert len( @@ -810,7 +807,6 @@ def forward_impl( relative_attention_bias: Optional[torch.Tensor] = None, relative_attention_max_distance: int = 0, has_lora: bool = False, - multi_item_part_lens: Optional[list[list[int]]] = None, ): mrope_rotary_cos_sin = None mrope_position_deltas = None @@ -863,7 +859,6 @@ def forward_impl( relative_attention_bias=relative_attention_bias, relative_attention_max_distance=relative_attention_max_distance, has_lora=has_lora, - multi_item_part_lens=multi_item_part_lens, ) if output_sf is not None: output = Fp4QuantizedTensor(output, output_sf) @@ -884,7 +879,6 @@ def forward( attention_sinks: Optional[torch.Tensor] = None, relative_attention_bias: Optional[torch.Tensor] = None, relative_attention_max_distance: int = 0, - multi_item_part_lens: Optional[list[list[int]]] = None, **kwargs, ) -> torch.Tensor: """ @@ -944,23 +938,6 @@ def forward( position_ids = self._adjust_position_ids_for_spec_dec( position_ids, attn_metadata) - if multi_item_part_lens is not None: - # adjust RoPE positions for multi-item scoring - current_idx = 0 - for req_multi_item_part_lens in multi_item_part_lens: - req_prefix_len, *req_multi_item_part_lens = req_multi_item_part_lens - # RoPE for prefix does not need updating and RoPE for delimiter does not matter - current_idx += req_prefix_len + 1 - for item_len in req_multi_item_part_lens: - next_idx = current_idx + item_len - position_ids[0, current_idx:next_idx].copy_( - torch.arange(req_prefix_len, - req_prefix_len + item_len, - dtype=position_ids.dtype, - device=position_ids.device), - non_blocking=True) - current_idx = next_idx + 1 # RoPE for delimiter does not matter - q, k, v = self.apply_rope(q, k, v, position_ids) q, k, v = self.convert_qkv(q, k, v) @@ -984,7 +961,6 @@ def forward( relative_attention_bias=relative_attention_bias, relative_attention_max_distance=relative_attention_max_distance, has_lora=bool(lora_params), - multi_item_part_lens=multi_item_part_lens, ) if self.attn_output_gate: diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e1551c88edd8..685bbf7ec058 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -4499,6 +4499,7 @@ def _prepare_encoder_inputs( input_ids = inputs['input_ids'] seq_lens = inputs['seq_lens'] # Only seq_lens includes padding position_ids = inputs.get('position_ids') + multi_item_part_lens = inputs.get('multi_item_part_lens') actual_num_tokens = len(input_ids) batch_size = len(seq_lens) @@ -4527,9 +4528,10 @@ def _prepare_encoder_inputs( attn_metadata.max_seq_len = self.max_seq_len attn_metadata.request_ids = list(range(batch_size)) if hasattr(attn_metadata, 'prepare_encoder_only'): - attn_metadata.prepare_encoder_only() + attn_metadata.prepare_encoder_only( + multi_item_part_lens=multi_item_part_lens) else: - attn_metadata.prepare() + attn_metadata.prepare(multi_item_part_lens=multi_item_part_lens) self.input_ids_cuda[:actual_num_tokens].copy_(input_ids_t, non_blocking=True) @@ -4548,6 +4550,10 @@ def _prepare_encoder_inputs( # CUDA graph hit path. assert self.encoder_cuda_graph_runner.enabled, "Encoder CUDA graph runner is not enabled" + # NB: As of 06/10/2026, the multi-item scoring arguments lacked '_buf' counterparts (cf., e.g., + # https://github.com/flashinfer-ai/flashinfer/blob/2aa1d49cf140d73ccdd3761051c5f2944406cb83/flashinfer/prefill.py#L1622 ). + assert multi_item_part_lens is None, "multi-item scoring with CUDA graph not implemented" + attn_metadata.prepare_encoder_cuda_graph_replay(seq_lens, padded_num_tokens) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 2659d46449e2..741e18846af5 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -16,6 +16,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase +from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.inputs.multimodal import (DisaggPrefillMultimodalInputs, MultimodalParams) @@ -23,7 +24,7 @@ from tensorrt_llm.llmapi import tracing from tensorrt_llm.metrics.enums import MetricNames -from .._utils import nvtx_range_debug +from .._utils import nvtx_range_debug, prefer_pinned from ..bindings import executor as tllm from ..bindings import steady_clock_now from ..builder import EngineConfig @@ -947,6 +948,7 @@ def preprocess( ) @set_api_status("prototype") + @torch.inference_mode() def encode( self, inputs: Union[PromptInputs, Sequence[PromptInputs]], @@ -1086,6 +1088,36 @@ def encode( ) forward_inputs["multi_item_part_lens"] = batch_multi_item_part_lens + # Scoring items have overlapping position IDs. Position IDs of delimiters + # are irrelevant. + starts_cuda = torch.tensor( + [ + start for multi_item_part_lens in batch_multi_item_part_lens + for start in [0] + [multi_item_part_lens[0]] * + (len(multi_item_part_lens) - 1) + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to("cuda", non_blocking=True) # uses current device + ends_cuda = torch.tensor( + [ + end + 1 + for multi_item_part_lens in batch_multi_item_part_lens + for end in [multi_item_part_lens[0]] + [ + multi_item_part_lens[0] + item_len + for item_len in multi_item_part_lens[1:] + ] + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to("cuda", non_blocking=True) + position_ids_cuda = torch_multi_arange( + starts=starts_cuda, + ends=ends_cuda, + output_length=len(flat_token_ids), + ) + forward_inputs["position_ids"] = position_ids_cuda + # Single forward pass outputs = self._encoder_executor.batch_forward(forward_inputs, **forward_kwargs) From ecd8accb1298b66b5034fb1865a321a0d635ece8 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Tue, 16 Jun 2026 11:45:26 +0000 Subject: [PATCH 02/10] update comment Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 685bbf7ec058..dc2e2b82a107 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -4550,7 +4550,7 @@ def _prepare_encoder_inputs( # CUDA graph hit path. assert self.encoder_cuda_graph_runner.enabled, "Encoder CUDA graph runner is not enabled" - # NB: As of 06/10/2026, the multi-item scoring arguments lacked '_buf' counterparts (cf., e.g., + # NB: The multi-item scoring arguments lack '_buf' counterparts (cf., e.g., # https://github.com/flashinfer-ai/flashinfer/blob/2aa1d49cf140d73ccdd3761051c5f2944406cb83/flashinfer/prefill.py#L1622 ). assert multi_item_part_lens is None, "multi-item scoring with CUDA graph not implemented" From 44ff7ffb35dcbf238ce166fb50a7067929e14258 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Tue, 16 Jun 2026 11:45:51 +0000 Subject: [PATCH 03/10] graceful fallback to eager Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 478639035c62..77cf6edeea97 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -830,6 +830,14 @@ def maybe_get_cuda_graph( if not self._capture_allowed: return None, None + if "multi_item_part_lens" in inputs: + # See model_engine.py for more details + logger.warning_once( + "Encoder CUDA graph does not support multi-item scoring; " + "falling back to eager.", + key="encoder_cuda_graph_multi_item_scoring_warning") + return None, None + if attn_metadata.has_cross_sub_metadata: logger.warning_once( "Encoder CUDA graph does not support cross-attention metadata; " From 1d91145c996589ca8ee1b3ca0d067eb4a6a41379 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Wed, 17 Jun 2026 09:15:34 +0000 Subject: [PATCH 04/10] address review comments Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../_torch/attention_backend/flashinfer.py | 38 ++++++++---- .../_torch/attention_backend/interface.py | 19 +++++- .../attention_backend/star_flashinfer.py | 8 +-- .../_torch/attention_backend/trtllm.py | 18 +----- .../_torch/attention_backend/vanilla.py | 9 +-- .../_torch/pyexecutor/model_engine.py | 58 ++++++++++++++++--- tensorrt_llm/llmapi/llm.py | 34 +---------- .../llmapi/test_llm_encode_multi_item.py | 2 +- 8 files changed, 103 insertions(+), 83 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index 574568200380..e0516d79f454 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -158,8 +158,11 @@ class FlashInferAttentionMetadata(AttentionMetadata): _mla_kv_len_arr_buf: Optional[torch.Tensor] = field(init=False, default=None) + _multi_item_part_lens: Optional[list[list[int]]] = field(init=False, + default=None) _multi_item_params: Optional[FlashInferMultiItemParams] = field( init=False, default=None) + _multi_item_params_needs_refresh: bool = field(init=False, default=False) def needs_plan(self, plan_params: PlanParams) -> bool: if plan_params not in self._plan_params_to_wrappers: @@ -599,7 +602,9 @@ def _post_init_with_buffers(self, buffers) -> None: self._mla_ragged_planned = False self._mla_context_planned = False self._mla_decode_planned = False + self._multi_item_part_lens = None self._multi_item_params = None + self._multi_item_params_needs_refresh = False def create_cuda_graph_metadata(self, max_batch_size: int, @@ -822,19 +827,37 @@ def _clean_cached_plans(self, *, defer_plan: bool): else: del self._plan_params_to_wrappers[plan_params] - def prepare(self, - *, - multi_item_part_lens: list[list[int]] | None = None) -> None: - super().prepare(multi_item_part_lens=multi_item_part_lens) + @property + def multi_item_part_lens(self) -> Optional[list[list[int]]]: + return self._multi_item_part_lens + + @multi_item_part_lens.setter + def multi_item_part_lens(self, + multi_item_part_lens: Optional[list[list[int]]]): + self._multi_item_part_lens = multi_item_part_lens + self._multi_item_params_needs_refresh = True + + def prepare(self) -> None: + super().prepare() extra_attrs = get_model_extra_attrs() if extra_attrs is None: get_global_attrs().attention_metadata = weakref.ref(self) # start and end indices of each sequence in the ragged query + assert self.seq_lens_cuda is not None torch.cumsum(self.seq_lens_cuda, dim=0, dtype=torch.int32, out=self._qo_indptr[1:self.seq_lens_cuda.size(0) + 1]) + if self._multi_item_params_needs_refresh: + if self._multi_item_part_lens is not None: + self._multi_item_params = self._process_multi_item_part_lens( + self._multi_item_part_lens, + device=self.seq_lens_cuda.device) + else: + self._multi_item_params = None + self._multi_item_params_needs_refresh = False + if self.kv_cache_manager is None: assert self.request_ids is not None assert self.num_generations == 0, ( @@ -848,15 +871,10 @@ def prepare(self, n = self.num_seqs self._cached_token_lens[:n].zero_() self.num_ctx_cached_tokens = 0 - if multi_item_part_lens is not None: - self._multi_item_params = self._process_multi_item_part_lens( - multi_item_part_lens, device=self.seq_lens_cuda.device) - else: - self._multi_item_params = None self._clean_cached_plans(defer_plan=False) return - if multi_item_part_lens is not None: + if self._multi_item_params is not None: raise ValueError( "multi_item_part_lens with KV cache is not supported") diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index d80f75ab2b2d..ffb0eaf11b2c 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -308,7 +308,24 @@ def num_ctx_tokens(self) -> int: def num_tokens(self) -> int: return self._num_tokens - def prepare(self, *, multi_item_part_lens: list[list[int]] | None = None): + @property + def multi_item_part_lens(self) -> Optional[list[list[int]]]: + """Additional token layout information for multi-item scoring. + + Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch, + see `TokensPrompt` for details. + """ + return None + + @multi_item_part_lens.setter + def multi_item_part_lens(self, + multi_item_part_lens: Optional[list[list[int]]]): + if multi_item_part_lens is not None: + raise ValueError( + "The selected attention backend does not support multi-item scoring." + ) + + def prepare(self): """ Hook to be called before the forward step of the model. """ diff --git a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py index 23fec0a898b8..c6fcd66cbb06 100644 --- a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py @@ -87,13 +87,7 @@ def num_generations(self) -> int: """ return self.seq_lens.shape[0] - self.num_contexts - self.num_queries - def prepare(self, - *, - multi_item_part_lens: list[list[int]] | None = None) -> None: - if multi_item_part_lens is not None: - raise ValueError( - "Star Attention does not support multi-item scoring") - + def prepare(self) -> None: context_lens = self.context_lens query_lens = self.query_lens # indices of used cache blocks for each sequence diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 7d839a3e2909..1a5b399d4e72 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -489,13 +489,8 @@ def update_helix_param( self.helix_is_inactive_rank[:batch_size].copy_( self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True) - def prepare(self, - *, - multi_item_part_lens: list[list[int]] | None = None) -> None: - if multi_item_part_lens is not None: - raise ValueError( - "TRT-LLM Attention does not support multi-item scoring") - super().prepare(multi_item_part_lens=multi_item_part_lens) + def prepare(self) -> None: + super().prepare() extra_attrs = get_model_extra_attrs() # If model extra attrs is set, attention_metadata is setup in executor. if extra_attrs is None: @@ -594,15 +589,8 @@ def prepare(self, self.host_request_types_runtime = self.host_request_types[:self. num_seqs] - def prepare_encoder_only( - self, - *, - multi_item_part_lens: list[list[int]] | None = None) -> None: + def prepare_encoder_only(self) -> None: """Fast path for encoder-only forward (eager + CUDA graph capture).""" - if multi_item_part_lens is not None: - raise ValueError( - "TRT-LLM Attention does not support multi-item scoring") - extra_attrs = get_model_extra_attrs() if extra_attrs is None: get_global_attrs().attention_metadata = weakref.ref(self) diff --git a/tensorrt_llm/_torch/attention_backend/vanilla.py b/tensorrt_llm/_torch/attention_backend/vanilla.py index 7155d248b45e..a65f900dd95b 100644 --- a/tensorrt_llm/_torch/attention_backend/vanilla.py +++ b/tensorrt_llm/_torch/attention_backend/vanilla.py @@ -60,13 +60,8 @@ def generate_sliding_window_mask(batch_size: int, target_length: int, class VanillaAttentionMetadata(AttentionMetadata): - def prepare(self, - *, - multi_item_part_lens: list[list[int]] | None = None) -> None: - if multi_item_part_lens is not None: - raise ValueError( - "Vanilla Attention does not support multi-item scoring") - super().prepare(multi_item_part_lens=multi_item_part_lens) + def prepare(self) -> None: + super().prepare() # indices of used cache blocks for each sequence assert self.request_ids is not None self.block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices( diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index dc2e2b82a107..1452bbf7ff73 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -14,6 +14,7 @@ import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub +from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange from tensorrt_llm._utils import (is_trace_enabled, maybe_pin_memory, nvtx_range, prefer_pinned, release_gc, torch_dtype_to_str, trace_func) @@ -4510,11 +4511,49 @@ def _prepare_encoder_inputs( dtype=torch.int, pin_memory=prefer_pinned()) if position_ids is None: - # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...] - position_ids_t = torch.cat([ - torch.arange(s, dtype=torch.int) for s in seq_lens - ])[:actual_num_tokens] - position_ids_t = maybe_pin_memory(position_ids_t) + if multi_item_part_lens is not None: + if len(multi_item_part_lens) != len(seq_lens): + raise ValueError( + "\"multi_item_part_lens\" must either be provided for all prompts or for none" + ) + + # Scoring items have overlapping position IDs. Position IDs of delimiters + # are irrelevant. + starts_cuda = torch.tensor( + [ + start + for multi_item_part_lens in multi_item_part_lens + for start in [0] + [multi_item_part_lens[0]] * + (len(multi_item_part_lens) - 1) + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=self.position_ids_cuda.device, + non_blocking=True) + ends_cuda = torch.tensor( + [ + end + 1 + for multi_item_part_lens in multi_item_part_lens + for end in [multi_item_part_lens[0]] + [ + multi_item_part_lens[0] + item_len + for item_len in multi_item_part_lens[1:] + ] + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=self.position_ids_cuda.device, + non_blocking=True) + position_ids_t = torch_multi_arange( + starts=starts_cuda, + ends=ends_cuda, + output_length=input_ids_t.numel(), + ) + else: + # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...] + position_ids_t = torch.cat([ + torch.arange(s, dtype=torch.int) for s in seq_lens + ])[:actual_num_tokens] + position_ids_t = maybe_pin_memory(position_ids_t) elif not isinstance(position_ids, torch.Tensor): position_ids_t = torch.tensor(position_ids, dtype=torch.int, @@ -4527,11 +4566,11 @@ def _prepare_encoder_inputs( attn_metadata.num_contexts = batch_size attn_metadata.max_seq_len = self.max_seq_len attn_metadata.request_ids = list(range(batch_size)) + attn_metadata.multi_item_part_lens = multi_item_part_lens if hasattr(attn_metadata, 'prepare_encoder_only'): - attn_metadata.prepare_encoder_only( - multi_item_part_lens=multi_item_part_lens) + attn_metadata.prepare_encoder_only() else: - attn_metadata.prepare(multi_item_part_lens=multi_item_part_lens) + attn_metadata.prepare() self.input_ids_cuda[:actual_num_tokens].copy_(input_ids_t, non_blocking=True) @@ -4551,7 +4590,8 @@ def _prepare_encoder_inputs( assert self.encoder_cuda_graph_runner.enabled, "Encoder CUDA graph runner is not enabled" # NB: The multi-item scoring arguments lack '_buf' counterparts (cf., e.g., - # https://github.com/flashinfer-ai/flashinfer/blob/2aa1d49cf140d73ccdd3761051c5f2944406cb83/flashinfer/prefill.py#L1622 ). + # https://github.com/flashinfer-ai/flashinfer/blob/2aa1d49cf140d73ccdd3761051c5f2944406cb83/flashinfer/prefill.py#L1622 ), + # which are typically used to support CUDA graphs in FlashInfer. assert multi_item_part_lens is None, "multi-item scoring with CUDA graph not implemented" attn_metadata.prepare_encoder_cuda_graph_replay(seq_lens, diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 741e18846af5..2659d46449e2 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -16,7 +16,6 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase -from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.inputs.multimodal import (DisaggPrefillMultimodalInputs, MultimodalParams) @@ -24,7 +23,7 @@ from tensorrt_llm.llmapi import tracing from tensorrt_llm.metrics.enums import MetricNames -from .._utils import nvtx_range_debug, prefer_pinned +from .._utils import nvtx_range_debug from ..bindings import executor as tllm from ..bindings import steady_clock_now from ..builder import EngineConfig @@ -948,7 +947,6 @@ def preprocess( ) @set_api_status("prototype") - @torch.inference_mode() def encode( self, inputs: Union[PromptInputs, Sequence[PromptInputs]], @@ -1088,36 +1086,6 @@ def encode( ) forward_inputs["multi_item_part_lens"] = batch_multi_item_part_lens - # Scoring items have overlapping position IDs. Position IDs of delimiters - # are irrelevant. - starts_cuda = torch.tensor( - [ - start for multi_item_part_lens in batch_multi_item_part_lens - for start in [0] + [multi_item_part_lens[0]] * - (len(multi_item_part_lens) - 1) - ], - pin_memory=prefer_pinned(), - dtype=torch.int32, - ).to("cuda", non_blocking=True) # uses current device - ends_cuda = torch.tensor( - [ - end + 1 - for multi_item_part_lens in batch_multi_item_part_lens - for end in [multi_item_part_lens[0]] + [ - multi_item_part_lens[0] + item_len - for item_len in multi_item_part_lens[1:] - ] - ], - pin_memory=prefer_pinned(), - dtype=torch.int32, - ).to("cuda", non_blocking=True) - position_ids_cuda = torch_multi_arange( - starts=starts_cuda, - ends=ends_cuda, - output_length=len(flat_token_ids), - ) - forward_inputs["position_ids"] = position_ids_cuda - # Single forward pass outputs = self._encoder_executor.batch_forward(forward_inputs, **forward_kwargs) diff --git a/tests/unittest/llmapi/test_llm_encode_multi_item.py b/tests/unittest/llmapi/test_llm_encode_multi_item.py index e77ff9490a67..c45c58ad2879 100644 --- a/tests/unittest/llmapi/test_llm_encode_multi_item.py +++ b/tests/unittest/llmapi/test_llm_encode_multi_item.py @@ -405,7 +405,7 @@ def test_unsupported_attention_backend( llm, pytest.raises( ValueError, - match=".*Attention does not support multi-item scoring.*", + match=".*The selected attention backend does not support multi-item scoring.*", ), ): llm.encode(prompt_inputs, batch_indexed_model_output=False) From 8380147c546da2d4675e5f81124e0280450587c3 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Wed, 17 Jun 2026 09:27:43 +0000 Subject: [PATCH 05/10] fix iteration variable name Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 1452bbf7ff73..2b7c651128a3 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -4522,9 +4522,9 @@ def _prepare_encoder_inputs( starts_cuda = torch.tensor( [ start - for multi_item_part_lens in multi_item_part_lens - for start in [0] + [multi_item_part_lens[0]] * - (len(multi_item_part_lens) - 1) + for req_multi_item_part_lens in multi_item_part_lens + for start in [0] + [req_multi_item_part_lens[0]] * + (len(req_multi_item_part_lens) - 1) ], pin_memory=prefer_pinned(), dtype=torch.int32, @@ -4533,10 +4533,10 @@ def _prepare_encoder_inputs( ends_cuda = torch.tensor( [ end + 1 - for multi_item_part_lens in multi_item_part_lens - for end in [multi_item_part_lens[0]] + [ - multi_item_part_lens[0] + item_len - for item_len in multi_item_part_lens[1:] + for req_multi_item_part_lens in multi_item_part_lens + for end in [req_multi_item_part_lens[0]] + [ + req_multi_item_part_lens[0] + item_len + for item_len in req_multi_item_part_lens[1:] ] ], pin_memory=prefer_pinned(), From aa2427e870973553b45d0af2dd09591d70990292 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Wed, 17 Jun 2026 09:49:21 +0000 Subject: [PATCH 06/10] remove FlashInferAttentionMetadata._multi_item_params_needs_refresh Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../_torch/attention_backend/flashinfer.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index e0516d79f454..3a423f2db09c 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -162,7 +162,6 @@ class FlashInferAttentionMetadata(AttentionMetadata): default=None) _multi_item_params: Optional[FlashInferMultiItemParams] = field( init=False, default=None) - _multi_item_params_needs_refresh: bool = field(init=False, default=False) def needs_plan(self, plan_params: PlanParams) -> bool: if plan_params not in self._plan_params_to_wrappers: @@ -604,7 +603,6 @@ def _post_init_with_buffers(self, buffers) -> None: self._mla_decode_planned = False self._multi_item_part_lens = None self._multi_item_params = None - self._multi_item_params_needs_refresh = False def create_cuda_graph_metadata(self, max_batch_size: int, @@ -835,7 +833,6 @@ def multi_item_part_lens(self) -> Optional[list[list[int]]]: def multi_item_part_lens(self, multi_item_part_lens: Optional[list[list[int]]]): self._multi_item_part_lens = multi_item_part_lens - self._multi_item_params_needs_refresh = True def prepare(self) -> None: super().prepare() @@ -849,14 +846,11 @@ def prepare(self) -> None: dtype=torch.int32, out=self._qo_indptr[1:self.seq_lens_cuda.size(0) + 1]) - if self._multi_item_params_needs_refresh: - if self._multi_item_part_lens is not None: - self._multi_item_params = self._process_multi_item_part_lens( - self._multi_item_part_lens, - device=self.seq_lens_cuda.device) - else: - self._multi_item_params = None - self._multi_item_params_needs_refresh = False + if self._multi_item_part_lens is not None: + self._multi_item_params = self._process_multi_item_part_lens( + self._multi_item_part_lens, device=self.seq_lens_cuda.device) + else: + self._multi_item_params = None if self.kv_cache_manager is None: assert self.request_ids is not None From b9aa2421ea8ec5289d39d8552b8a568800133678 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Wed, 17 Jun 2026 12:01:42 +0000 Subject: [PATCH 07/10] make AttentionMetadata.multi_item_part_lens a plain attribute Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- .../_torch/attention_backend/flashinfer.py | 24 ++++++---------- .../_torch/attention_backend/interface.py | 28 ++++++++----------- .../_torch/pyexecutor/model_engine.py | 5 ++++ 3 files changed, 25 insertions(+), 32 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index 3a423f2db09c..8da6e724c976 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -3,7 +3,8 @@ import os import weakref from dataclasses import dataclass, field -from typing import Any, Dict, Literal, NewType, Optional, TypeAlias, cast +from typing import (Any, Dict, Literal, NewType, Optional, TypeAlias, cast, + override) import flashinfer import torch @@ -158,8 +159,6 @@ class FlashInferAttentionMetadata(AttentionMetadata): _mla_kv_len_arr_buf: Optional[torch.Tensor] = field(init=False, default=None) - _multi_item_part_lens: Optional[list[list[int]]] = field(init=False, - default=None) _multi_item_params: Optional[FlashInferMultiItemParams] = field( init=False, default=None) @@ -601,7 +600,6 @@ def _post_init_with_buffers(self, buffers) -> None: self._mla_ragged_planned = False self._mla_context_planned = False self._mla_decode_planned = False - self._multi_item_part_lens = None self._multi_item_params = None def create_cuda_graph_metadata(self, @@ -825,15 +823,6 @@ def _clean_cached_plans(self, *, defer_plan: bool): else: del self._plan_params_to_wrappers[plan_params] - @property - def multi_item_part_lens(self) -> Optional[list[list[int]]]: - return self._multi_item_part_lens - - @multi_item_part_lens.setter - def multi_item_part_lens(self, - multi_item_part_lens: Optional[list[list[int]]]): - self._multi_item_part_lens = multi_item_part_lens - def prepare(self) -> None: super().prepare() extra_attrs = get_model_extra_attrs() @@ -846,9 +835,9 @@ def prepare(self) -> None: dtype=torch.int32, out=self._qo_indptr[1:self.seq_lens_cuda.size(0) + 1]) - if self._multi_item_part_lens is not None: + if self.multi_item_part_lens is not None: self._multi_item_params = self._process_multi_item_part_lens( - self._multi_item_part_lens, device=self.seq_lens_cuda.device) + self.multi_item_part_lens, device=self.seq_lens_cuda.device) else: self._multi_item_params = None @@ -1324,6 +1313,11 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]): def support_mla(cls) -> bool: return True + @override + @classmethod + def support_multi_item_scoring(cls) -> bool: + return True + def __init__( self, layer_idx: int, diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index ffb0eaf11b2c..9ea77182ebb4 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -174,6 +174,13 @@ class AttentionMetadata: # The number of heads per kv head. num_heads_per_kv: Optional[int] = 1 + multi_item_part_lens: Optional[list[list[int]]] = None + """Additional token layout information for multi-item scoring. + + Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch, + see `TokensPrompt` for details. + """ + def __post_init__(self) -> None: if self.is_cross: assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata" @@ -308,23 +315,6 @@ def num_ctx_tokens(self) -> int: def num_tokens(self) -> int: return self._num_tokens - @property - def multi_item_part_lens(self) -> Optional[list[list[int]]]: - """Additional token layout information for multi-item scoring. - - Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch, - see `TokensPrompt` for details. - """ - return None - - @multi_item_part_lens.setter - def multi_item_part_lens(self, - multi_item_part_lens: Optional[list[list[int]]]): - if multi_item_part_lens is not None: - raise ValueError( - "The selected attention backend does not support multi-item scoring." - ) - def prepare(self): """ Hook to be called before the forward step of the model. @@ -994,6 +984,10 @@ def support_fused_qkv(cls) -> bool: def support_mla(cls) -> bool: return False + @classmethod + def support_multi_item_scoring(cls) -> bool: + return False + def create_output(self, q: torch.Tensor, **kwargs) -> List[torch.Tensor]: """ Create the output tensors for the attention operation. diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 2b7c651128a3..536dbaf5277d 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -4566,6 +4566,11 @@ def _prepare_encoder_inputs( attn_metadata.num_contexts = batch_size attn_metadata.max_seq_len = self.max_seq_len attn_metadata.request_ids = list(range(batch_size)) + if multi_item_part_lens is not None and not self.attn_backend.support_multi_item_scoring( + ): + raise ValueError( + "The selected attention backend does not support multi-item scoring." + ) attn_metadata.multi_item_part_lens = multi_item_part_lens if hasattr(attn_metadata, 'prepare_encoder_only'): attn_metadata.prepare_encoder_only() From 48cd0d188cbd8bd8b6c4bc5ebc2b73d9944a8baa Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Wed, 17 Jun 2026 15:26:48 +0000 Subject: [PATCH 08/10] fix for Python < 3.12 Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/flashinfer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index 8da6e724c976..77861d3c786e 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -1,10 +1,15 @@ import functools import math import os +import sys import weakref from dataclasses import dataclass, field -from typing import (Any, Dict, Literal, NewType, Optional, TypeAlias, cast, - override) +from typing import Any, Dict, Literal, NewType, Optional, TypeAlias, cast + +if sys.version_info[:2] >= (3, 12): + from typing import override +else: + from typing_extensions import override import flashinfer import torch From 0810e3cb05f69a9e487b28b513e76fcf12cb6983 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Mon, 22 Jun 2026 11:19:40 +0200 Subject: [PATCH 09/10] exclude NumPy 2.5.0 Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- constraints.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/constraints.txt b/constraints.txt index 519a9a29f9fc..bceb59ff6031 100644 --- a/constraints.txt +++ b/constraints.txt @@ -17,3 +17,5 @@ mistune>=3.2.1 notebook>=7.5.6 # WAR against https://github.com/advisories/GHSA-rch3-82jr-f9w9 jupyter_server>=2.18.0 +# NumPy 2.5.0 breaks type checking in CI +numpy<2.5.0 From 90103f59abc8d55f25a12ad36407637131b5c39c Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:58:08 +0200 Subject: [PATCH 10/10] Revert "exclude NumPy 2.5.0" This reverts commit 0810e3cb05f69a9e487b28b513e76fcf12cb6983. Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- constraints.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/constraints.txt b/constraints.txt index bceb59ff6031..519a9a29f9fc 100644 --- a/constraints.txt +++ b/constraints.txt @@ -17,5 +17,3 @@ mistune>=3.2.1 notebook>=7.5.6 # WAR against https://github.com/advisories/GHSA-rch3-82jr-f9w9 jupyter_server>=2.18.0 -# NumPy 2.5.0 breaks type checking in CI -numpy<2.5.0