Skip to content

Commit 08f4bb1

Browse files
authored
[None][fix] Fix stale sparse attention kwargs (#15460)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
1 parent 79ea125 commit 08f4bb1

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,7 @@ def _create_cross_kv_cache_manager(
13011301
max_seq_len=max_seq_len,
13021302
max_batch_size=self._max_batch_size,
13031303
spec_config=None,
1304-
sparse_attn_config=None,
1304+
sparse_attention_config=None,
13051305
max_num_tokens=self._max_num_tokens,
13061306
max_beam_width=1,
13071307
kv_connector_manager=None,

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5234,6 +5234,10 @@ def _prepare_tp_inputs_encoder(
52345234
# Build a fresh, no-cache attention metadata for the encoder
52355235
# pass. We do not reuse ``self.attn_metadata`` because that
52365236
# object is bound to the decoder's KV-cache manager.
5237+
sparse_metadata_params = (
5238+
self.sparse_attention_config.to_sparse_metadata_params(
5239+
pretrained_config=self.model.model_config.pretrained_config)
5240+
if self.sparse_attention_config is not None else None)
52375241
encoder_attn_metadata = self.attn_backend.Metadata(
52385242
max_num_requests=self.batch_size,
52395243
max_num_tokens=self.max_num_tokens,
@@ -5244,7 +5248,7 @@ def _prepare_tp_inputs_encoder(
52445248
enable_flash_mla=self.model.model_config.enable_flash_mla,
52455249
enable_context_mla_with_cached_kv=False,
52465250
cache_indirection=None,
5247-
sparse_attention_config=self.sparse_attention_config,
5251+
sparse_metadata_params=sparse_metadata_params,
52485252
num_heads_per_kv=1,
52495253
)
52505254
assert isinstance(

0 commit comments

Comments
 (0)