diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index 7ced64dec19a..77861d3c786e 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -1,16 +1,23 @@ import functools import math import os +import sys import weakref from dataclasses import dataclass, field from typing import Any, Dict, Literal, NewType, Optional, TypeAlias, cast +if sys.version_info[:2] >= (3, 12): + from typing import override +else: + from typing_extensions import override + import flashinfer import torch from flashinfer.jit.core import check_cuda_arch from typing_extensions import Self from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange +from tensorrt_llm._utils import nvtx_range from tensorrt_llm.functional import AttentionMaskType from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantConfig @@ -157,6 +164,9 @@ class FlashInferAttentionMetadata(AttentionMetadata): _mla_kv_len_arr_buf: Optional[torch.Tensor] = field(init=False, default=None) + _multi_item_params: Optional[FlashInferMultiItemParams] = field( + init=False, default=None) + def needs_plan(self, plan_params: PlanParams) -> bool: if plan_params not in self._plan_params_to_wrappers: return True @@ -595,6 +605,7 @@ def _post_init_with_buffers(self, buffers) -> None: self._mla_ragged_planned = False self._mla_context_planned = False self._mla_decode_planned = False + self._multi_item_params = None def create_cuda_graph_metadata(self, max_batch_size: int, @@ -722,6 +733,89 @@ def _plan_ragged_no_kv( **plan_kwargs, ) + def _process_multi_item_part_lens( + self, + multi_item_part_lens: list[list[int]], + *, + device: torch.device, + ) -> FlashInferMultiItemParams: + if self.num_generations > 0: + raise ValueError( + "\"multi_item_part_lens\" not supported for generation requests." + ) + if len(multi_item_part_lens) != self.num_contexts: + raise ValueError( + "\"multi_item_part_lens\" needs to be provided for all requests." + ) + + prefix_len_ptr = torch.tensor( + [req_part_lens[0] for req_part_lens in multi_item_part_lens], + pin_memory=prefer_pinned(), + dtype=torch.uint32, + ).to(device=device, non_blocking=True) + token_pos_in_items_raw_lens = [ # 'raw' lengths before padding + sum(req_part_lens[1:]) + len(req_part_lens) + for req_part_lens in multi_item_part_lens + ] + token_pos_in_items_len = max(token_pos_in_items_raw_lens) + max_item_len_ptr = torch.tensor( + [max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens], + pin_memory=prefer_pinned(), + dtype=torch.uint16, + ).to(device=device, non_blocking=True) + + # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in + # every request, followed by [0] (final delimiter) which is fused with padding for simplicity. + range_ends = torch.tensor( + [ + item_len + 1 + for req_part_lens, token_pos_in_items_raw_len in zip( + multi_item_part_lens, token_pos_in_items_raw_lens) + for item_len in ( + req_part_lens[1:] + + [token_pos_in_items_len - token_pos_in_items_raw_len]) + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=device, non_blocking=True) + token_pos_in_items_ptr = torch_multi_arange( + range_ends, + output_length=(token_pos_in_items_len * len(multi_item_part_lens)), + ) + # next, mask out the padding + mask_entries = torch.arange(2, dtype=torch.uint8).to( + device=device, + non_blocking=True, + dtype=torch.bool, + ).repeat(len(multi_item_part_lens)) # NB: .expand() does not work here + mask_entry_repeats = torch.tensor( + [ + repeat + for token_pos_in_items_raw_len in token_pos_in_items_raw_lens + for repeat in [ + token_pos_in_items_raw_len, + token_pos_in_items_len - token_pos_in_items_raw_len, + ] + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=device, non_blocking=True) + padding_mask = torch.repeat_interleave( + input=mask_entries, + repeats=mask_entry_repeats, + output_size=token_pos_in_items_ptr.size(0), + ) + token_pos_in_items_ptr.masked_fill_(padding_mask, 0) + token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16, + non_blocking=True) + + return FlashInferMultiItemParams( + prefix_len_ptr=prefix_len_ptr, + max_item_len_ptr=max_item_len_ptr, + token_pos_in_items_ptr=token_pos_in_items_ptr, + token_pos_in_items_len=token_pos_in_items_len, + ) + def _clean_cached_plans(self, *, defer_plan: bool): for plan_params in list(self._plan_params_to_wrappers.keys()): # Generally, plan_params with non-trivial attention masking are relevant only the @@ -740,11 +834,18 @@ def prepare(self) -> None: if extra_attrs is None: get_global_attrs().attention_metadata = weakref.ref(self) # start and end indices of each sequence in the ragged query + assert self.seq_lens_cuda is not None torch.cumsum(self.seq_lens_cuda, dim=0, dtype=torch.int32, out=self._qo_indptr[1:self.seq_lens_cuda.size(0) + 1]) + if self.multi_item_part_lens is not None: + self._multi_item_params = self._process_multi_item_part_lens( + self.multi_item_part_lens, device=self.seq_lens_cuda.device) + else: + self._multi_item_params = None + if self.kv_cache_manager is None: assert self.request_ids is not None assert self.num_generations == 0, ( @@ -761,6 +862,10 @@ def prepare(self) -> None: self._clean_cached_plans(defer_plan=False) return + if self._multi_item_params is not None: + raise ValueError( + "multi_item_part_lens with KV cache is not supported") + # indices of used cache blocks for each sequence assert self.request_ids is not None block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices( @@ -1009,7 +1114,6 @@ def plan(self, q_scaling: Optional[float] = None, attention_window_size: Optional[int] = None, attention_mask_data: Optional[torch.Tensor] = None, - multi_item_params: Optional[FlashInferMultiItemParams] = None, flashinfer_backend: str = "fa2") -> PlanParams: sm_scale = None @@ -1027,7 +1131,7 @@ def plan(self, if attention_window_size is not None else -1, attention_mask_type=AttentionMaskType(attention_mask_type), attention_mask_data=attention_mask_data, - multi_item_params=multi_item_params, + multi_item_params=self._multi_item_params, ) return self._plan_with_params(plan_params, flashinfer_backend) @@ -1214,6 +1318,11 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]): def support_mla(cls) -> bool: return True + @override + @classmethod + def support_multi_item_scoring(cls) -> bool: + return True + def __init__( self, layer_idx: int, @@ -1247,90 +1356,6 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]): self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache( ) - @staticmethod - def _process_multi_item_part_lens( - multi_item_part_lens: list[list[int]], - *, - metadata: FlashInferAttentionMetadata, - device: torch.device, - ) -> FlashInferMultiItemParams: - if metadata.num_generations > 0: - raise ValueError( - "\"multi_item_part_lens\" not supported for generation requests." - ) - if len(multi_item_part_lens) != metadata.num_contexts: - raise ValueError( - "\"multi_item_part_lens\" needs to be provided for all requests." - ) - - prefix_len_ptr = torch.tensor( - [req_part_lens[0] for req_part_lens in multi_item_part_lens], - pin_memory=prefer_pinned(), - dtype=torch.uint32, - ).to(device=device, non_blocking=True) - token_pos_in_items_raw_lens = [ # 'raw' lengths before padding - sum(req_part_lens[1:]) + len(req_part_lens) - for req_part_lens in multi_item_part_lens - ] - token_pos_in_items_len = max(token_pos_in_items_raw_lens) - max_item_len_ptr = torch.tensor( - [max(req_part_lens[1:]) for req_part_lens in multi_item_part_lens], - pin_memory=prefer_pinned(), - dtype=torch.uint16, - ).to(device=device, non_blocking=True) - - # token_pos_in_items_ptr is obtained by concatenating range(item_len + 1) for each item in - # every request, followed by [0] (final delimiter) which is fused with padding for simplicity. - range_ends = torch.tensor( - [ - item_len + 1 - for req_part_lens, token_pos_in_items_raw_len in zip( - multi_item_part_lens, token_pos_in_items_raw_lens) - for item_len in ( - req_part_lens[1:] + - [token_pos_in_items_len - token_pos_in_items_raw_len]) - ], - pin_memory=prefer_pinned(), - dtype=torch.int32, - ).to(device=device, non_blocking=True) - token_pos_in_items_ptr = torch_multi_arange( - range_ends, - output_length=(token_pos_in_items_len * len(multi_item_part_lens)), - ) - # next, mask out the padding - mask_entries = torch.arange(2, dtype=torch.uint8).to( - device=device, - non_blocking=True, - dtype=torch.bool, - ).repeat(len(multi_item_part_lens)) # NB: .expand() does not work here - mask_entry_repeats = torch.tensor( - [ - repeat - for token_pos_in_items_raw_len in token_pos_in_items_raw_lens - for repeat in [ - token_pos_in_items_raw_len, - token_pos_in_items_len - token_pos_in_items_raw_len, - ] - ], - pin_memory=prefer_pinned(), - dtype=torch.int32, - ).to(device=device, non_blocking=True) - padding_mask = torch.repeat_interleave( - input=mask_entries, - repeats=mask_entry_repeats, - output_size=token_pos_in_items_ptr.size(0), - ) - token_pos_in_items_ptr.masked_fill_(padding_mask, 0) - token_pos_in_items_ptr = token_pos_in_items_ptr.to(dtype=torch.uint16, - non_blocking=True) - - return FlashInferMultiItemParams( - prefix_len_ptr=prefix_len_ptr, - max_item_len_ptr=max_item_len_ptr, - token_pos_in_items_ptr=token_pos_in_items_ptr, - token_pos_in_items_len=token_pos_in_items_len, - ) - def mla_rope_generation( self, fused_q: torch.Tensor, @@ -1610,7 +1635,6 @@ def forward_impl( output: torch.Tensor, attention_mask_data: Optional[torch.Tensor] = None, attention_window_size: Optional[int] = None, - multi_item_part_lens: Optional[list[list[int]]] = None, latent_cache: Optional[torch.Tensor] = None, attention_input_type: AttentionInputType = AttentionInputType.mixed, ) -> None: @@ -1648,14 +1672,6 @@ def forward_impl( # Query q = q.view(-1, self.num_heads, self.head_dim) - multi_item_params: FlashInferMultiItemParams | None = None - if multi_item_part_lens is not None: - multi_item_params = self._process_multi_item_part_lens( - multi_item_part_lens, - metadata=metadata, - device=q.device, - ) - if metadata.kv_cache_manager is None: assert k is not None and v is not None, ( "FlashInfer without a KV cache manager requires key/value tensors." @@ -1670,18 +1686,18 @@ def forward_impl( assert v.shape == (q.size(0), self.num_kv_heads * self.head_dim) k = k.view(-1, self.num_kv_heads, self.head_dim) v = v.view(-1, self.num_kv_heads, self.head_dim) - plan_params = metadata.plan( - self.num_heads, - self.num_kv_heads, - self.head_dim, - q_dtype=q.dtype, - kv_dtype=k.dtype, - q_scaling=self.q_scaling, - attention_window_size=attention_window_size, - attention_mask_type=attention_mask_type, - attention_mask_data=attention_mask_data, - multi_item_params=multi_item_params, - ) + with nvtx_range("metadata.plan"): + plan_params = metadata.plan( + self.num_heads, + self.num_kv_heads, + self.head_dim, + q_dtype=q.dtype, + kv_dtype=k.dtype, + q_scaling=self.q_scaling, + attention_window_size=attention_window_size, + attention_mask_type=attention_mask_type, + attention_mask_data=attention_mask_data, + ) wrapper = metadata.get_ragged_prefill_wrapper(plan_params) if isinstance(wrapper, flashinfer.BatchPrefillWithPagedKVCacheWrapper): @@ -1700,11 +1716,6 @@ def forward_impl( ) return - if multi_item_part_lens is not None: - raise ValueError( - "Multi-item masking support not implemented for paged KV cache." - ) - # Key and Value kv_cache = metadata.kv_cache_manager.get_buffers( self.layer_idx, kv_layout=metadata.kv_layout) @@ -1901,6 +1912,5 @@ def forward(self, output=output, latent_cache=latent_cache, attention_input_type=forward_args.attention_input_type, - multi_item_part_lens=forward_args.multi_item_part_lens, ) return output diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index c1630bbf5041..9ea77182ebb4 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -174,6 +174,13 @@ class AttentionMetadata: # The number of heads per kv head. num_heads_per_kv: Optional[int] = 1 + multi_item_part_lens: Optional[list[list[int]]] = None + """Additional token layout information for multi-item scoring. + + Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch, + see `TokensPrompt` for details. + """ + def __post_init__(self) -> None: if self.is_cross: assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata" @@ -835,13 +842,6 @@ class AttentionForwardArgs: relative_attention_max_distance: int = 0 cross_kv: Optional[torch.Tensor] = None - multi_item_part_lens: Optional[list[list[int]]] = None - """Additional token layout information for multi-item scoring. - - Aggregates `TokensPrompt.multi_item_part_lens` for all requests in the batch, - see `TokensPrompt` for details. - """ - latent_cache: Optional[torch.Tensor] = None q_pe: Optional[torch.Tensor] = None mrope_rotary_cos_sin: Optional[torch.Tensor] = None @@ -984,6 +984,10 @@ def support_fused_qkv(cls) -> bool: def support_mla(cls) -> bool: return False + @classmethod + def support_multi_item_scoring(cls) -> bool: + return False + def create_output(self, q: torch.Tensor, **kwargs) -> List[torch.Tensor]: """ Create the output tensors for the attention operation. diff --git a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py index 8a396ab09fbc..c6fcd66cbb06 100644 --- a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py @@ -322,10 +322,6 @@ def forward(self, ) assert not metadata.is_cross, "Star Attention does not support cross attention yet." - if forward_args.multi_item_part_lens is not None: - raise ValueError( - "Star Attention does not support multi-item scoring") - q = q.view(-1, self.num_heads, self.head_dim) k = k.view(-1, self.num_kv_heads, self.head_dim) v = v.view(-1, self.num_kv_heads, self.head_dim) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 93e1a2bfe4c4..55b677d63812 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1439,10 +1439,6 @@ def forward( # Cross-attention uses the THOP path; the trtllm-gen backend API does # not carry encoder K/V tensors yet. - if forward_args.multi_item_part_lens is not None: - raise ValueError( - "TRT-LLM Attention does not support multi-item scoring") - # SM90 forces ``use_paged_context_fmha`` on for correctness # (https://nvbugs/5624818). if get_sm_version() == 90: diff --git a/tensorrt_llm/_torch/attention_backend/vanilla.py b/tensorrt_llm/_torch/attention_backend/vanilla.py index 511dee84b743..a65f900dd95b 100644 --- a/tensorrt_llm/_torch/attention_backend/vanilla.py +++ b/tensorrt_llm/_torch/attention_backend/vanilla.py @@ -488,10 +488,6 @@ def forward(self, **kwargs) -> torch.Tensor: forward_args = merge_attention_forward_args(forward_args, kwargs) - if forward_args.multi_item_part_lens is not None: - raise ValueError( - "Vanilla Attention does not support multi-item scoring") - if metadata.kv_cache_manager is None: # NOTE: WAR for no kv cache attn e.g. BERT, # try to separate the kv cache estimation path from no kv cache attn. diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 28d5cff5c3d8..cc0694728829 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -698,7 +698,6 @@ def _attn_impl( relative_attention_bias: Optional[torch.Tensor] = None, relative_attention_max_distance: int = 0, has_lora: bool = False, - multi_item_part_lens: Optional[list[list[int]]] = None, ): num_tokens = attn_metadata.num_tokens @@ -738,7 +737,6 @@ def _attn_impl( relative_attention_bias=relative_attention_bias, relative_attention_max_distance= relative_attention_max_distance, - multi_item_part_lens=multi_item_part_lens, )) if isinstance(attn_output, tuple): attn_output = attn_output[0] @@ -787,7 +785,6 @@ def _attn_impl( attention_sinks=attention_sinks, relative_attention_bias=relative_attention_bias, relative_attention_max_distance=relative_attention_max_distance, - multi_item_part_lens=multi_item_part_lens, )) if isinstance(attn_output, tuple): assert len( @@ -810,7 +807,6 @@ def forward_impl( relative_attention_bias: Optional[torch.Tensor] = None, relative_attention_max_distance: int = 0, has_lora: bool = False, - multi_item_part_lens: Optional[list[list[int]]] = None, ): mrope_rotary_cos_sin = None mrope_position_deltas = None @@ -863,7 +859,6 @@ def forward_impl( relative_attention_bias=relative_attention_bias, relative_attention_max_distance=relative_attention_max_distance, has_lora=has_lora, - multi_item_part_lens=multi_item_part_lens, ) if output_sf is not None: output = Fp4QuantizedTensor(output, output_sf) @@ -884,7 +879,6 @@ def forward( attention_sinks: Optional[torch.Tensor] = None, relative_attention_bias: Optional[torch.Tensor] = None, relative_attention_max_distance: int = 0, - multi_item_part_lens: Optional[list[list[int]]] = None, **kwargs, ) -> torch.Tensor: """ @@ -944,23 +938,6 @@ def forward( position_ids = self._adjust_position_ids_for_spec_dec( position_ids, attn_metadata) - if multi_item_part_lens is not None: - # adjust RoPE positions for multi-item scoring - current_idx = 0 - for req_multi_item_part_lens in multi_item_part_lens: - req_prefix_len, *req_multi_item_part_lens = req_multi_item_part_lens - # RoPE for prefix does not need updating and RoPE for delimiter does not matter - current_idx += req_prefix_len + 1 - for item_len in req_multi_item_part_lens: - next_idx = current_idx + item_len - position_ids[0, current_idx:next_idx].copy_( - torch.arange(req_prefix_len, - req_prefix_len + item_len, - dtype=position_ids.dtype, - device=position_ids.device), - non_blocking=True) - current_idx = next_idx + 1 # RoPE for delimiter does not matter - q, k, v = self.apply_rope(q, k, v, position_ids) q, k, v = self.convert_qkv(q, k, v) @@ -984,7 +961,6 @@ def forward( relative_attention_bias=relative_attention_bias, relative_attention_max_distance=relative_attention_max_distance, has_lora=bool(lora_params), - multi_item_part_lens=multi_item_part_lens, ) if self.attn_output_gate: diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 478639035c62..77cf6edeea97 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -830,6 +830,14 @@ def maybe_get_cuda_graph( if not self._capture_allowed: return None, None + if "multi_item_part_lens" in inputs: + # See model_engine.py for more details + logger.warning_once( + "Encoder CUDA graph does not support multi-item scoring; " + "falling back to eager.", + key="encoder_cuda_graph_multi_item_scoring_warning") + return None, None + if attn_metadata.has_cross_sub_metadata: logger.warning_once( "Encoder CUDA graph does not support cross-attention metadata; " diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index ff1e5f8a5306..3a680cbb919a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -14,6 +14,7 @@ import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub +from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange from tensorrt_llm._utils import (is_trace_enabled, maybe_pin_memory, nvtx_range, prefer_pinned, release_gc, torch_dtype_to_str, trace_func) @@ -4510,6 +4511,7 @@ def _prepare_encoder_inputs( input_ids = inputs['input_ids'] seq_lens = inputs['seq_lens'] # Only seq_lens includes padding position_ids = inputs.get('position_ids') + multi_item_part_lens = inputs.get('multi_item_part_lens') actual_num_tokens = len(input_ids) batch_size = len(seq_lens) @@ -4520,11 +4522,49 @@ def _prepare_encoder_inputs( dtype=torch.int, pin_memory=prefer_pinned()) if position_ids is None: - # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...] - position_ids_t = torch.cat([ - torch.arange(s, dtype=torch.int) for s in seq_lens - ])[:actual_num_tokens] - position_ids_t = maybe_pin_memory(position_ids_t) + if multi_item_part_lens is not None: + if len(multi_item_part_lens) != len(seq_lens): + raise ValueError( + "\"multi_item_part_lens\" must either be provided for all prompts or for none" + ) + + # Scoring items have overlapping position IDs. Position IDs of delimiters + # are irrelevant. + starts_cuda = torch.tensor( + [ + start + for req_multi_item_part_lens in multi_item_part_lens + for start in [0] + [req_multi_item_part_lens[0]] * + (len(req_multi_item_part_lens) - 1) + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=self.position_ids_cuda.device, + non_blocking=True) + ends_cuda = torch.tensor( + [ + end + 1 + for req_multi_item_part_lens in multi_item_part_lens + for end in [req_multi_item_part_lens[0]] + [ + req_multi_item_part_lens[0] + item_len + for item_len in req_multi_item_part_lens[1:] + ] + ], + pin_memory=prefer_pinned(), + dtype=torch.int32, + ).to(device=self.position_ids_cuda.device, + non_blocking=True) + position_ids_t = torch_multi_arange( + starts=starts_cuda, + ends=ends_cuda, + output_length=input_ids_t.numel(), + ) + else: + # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...] + position_ids_t = torch.cat([ + torch.arange(s, dtype=torch.int) for s in seq_lens + ])[:actual_num_tokens] + position_ids_t = maybe_pin_memory(position_ids_t) elif not isinstance(position_ids, torch.Tensor): position_ids_t = torch.tensor(position_ids, dtype=torch.int, @@ -4537,6 +4577,12 @@ def _prepare_encoder_inputs( attn_metadata.num_contexts = batch_size attn_metadata.max_seq_len = self.max_seq_len attn_metadata.request_ids = list(range(batch_size)) + if multi_item_part_lens is not None and not self.attn_backend.support_multi_item_scoring( + ): + raise ValueError( + "The selected attention backend does not support multi-item scoring." + ) + attn_metadata.multi_item_part_lens = multi_item_part_lens if hasattr(attn_metadata, 'prepare_encoder_only'): attn_metadata.prepare_encoder_only() else: @@ -4559,6 +4605,11 @@ def _prepare_encoder_inputs( # CUDA graph hit path. assert self.encoder_cuda_graph_runner.enabled, "Encoder CUDA graph runner is not enabled" + # NB: The multi-item scoring arguments lack '_buf' counterparts (cf., e.g., + # https://github.com/flashinfer-ai/flashinfer/blob/2aa1d49cf140d73ccdd3761051c5f2944406cb83/flashinfer/prefill.py#L1622 ), + # which are typically used to support CUDA graphs in FlashInfer. + assert multi_item_part_lens is None, "multi-item scoring with CUDA graph not implemented" + attn_metadata.prepare_encoder_cuda_graph_replay(seq_lens, padded_num_tokens) diff --git a/tests/unittest/llmapi/test_llm_encode_multi_item.py b/tests/unittest/llmapi/test_llm_encode_multi_item.py index e77ff9490a67..c45c58ad2879 100644 --- a/tests/unittest/llmapi/test_llm_encode_multi_item.py +++ b/tests/unittest/llmapi/test_llm_encode_multi_item.py @@ -405,7 +405,7 @@ def test_unsupported_attention_backend( llm, pytest.raises( ValueError, - match=".*Attention does not support multi-item scoring.*", + match=".*The selected attention backend does not support multi-item scoring.*", ), ): llm.encode(prompt_inputs, batch_indexed_model_output=False)