Skip to content

Commit e2c608d

Browse files
codex addressed review comments
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 64e0739 commit e2c608d

13 files changed

Lines changed: 286 additions & 206 deletions

File tree

examples/auto_deploy/model_registry/configs/disagg_ctx.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33
cache_transceiver_config:
44
backend: DEFAULT
5+
# Prefill workers must run without overlap scheduling because the current
6+
# context-only transfer path sends KV cache when a request completes its context
7+
# phase, after all prefill chunks have run.
58
disable_overlap_scheduler: true

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ def get_cache_initializers(
589589
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
590590
kv_factor=2,
591591
kv_layout=_GlobalFlashInferPlanner.kv_layout,
592+
attention_type="mha",
592593
)
593594
}
594595

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,7 @@ def get_cache_initializers(
15061506
dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype),
15071507
kv_factor=2,
15081508
kv_layout=KV_LAYOUT,
1509+
attention_type="mha",
15091510
)
15101511
}
15111512

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ def get_cache_initializers(
877877
dtype=cls.resolve_cache_dtype(cache_config.dtype, kv_dtype),
878878
kv_factor=2,
879879
kv_layout="HND",
880+
attention_type="mha",
880881
)
881882
}
882883

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ..utils.node_utils import extract_op_args, get_op_schema
4040

4141
Constant = Union[int, float, str, None]
42-
AttentionType = Literal["default", "mla"]
42+
AttentionType = Literal["mha", "mla"]
4343

4444
# Torch dtype → numpy dtype for fast list-to-tensor conversion.
4545
# numpy's list→array conversion is ~2-3x faster than torch.tensor(list) for large lists.
@@ -644,7 +644,7 @@ def __init__(
644644
# will store num_blocks later...
645645
self._num_blocks = None
646646

647-
self.attention_type: AttentionType = "default"
647+
self.attention_type: Optional[AttentionType] = None
648648

649649
# TODO (lucaslie): can we eventually remove this from this interface?
650650
self.vocab_size_padded = vocab_size_padded
@@ -1582,7 +1582,7 @@ class KVPagedResourceHandler(ResourceHandler):
15821582
kv_factor: The factor of the KV cache. Default is 2 for combined k/v cache.
15831583
kv_layout: Memory layout for the KV cache. Either "HND" (head-num-dim) or "NHD" (num-head-dim).
15841584
Default is "HND" which is the standard layout for flashinfer.
1585-
attention_type: Attention semantics for this cache resource.
1585+
attention_type: Attention layout semantics for this cache resource: ``"mha"`` or ``"mla"``.
15861586
"""
15871587

15881588
@property
@@ -1595,9 +1595,9 @@ def __init__(
15951595
num_kv_heads: int,
15961596
head_dim: int,
15971597
dtype: torch.dtype,
1598+
attention_type: AttentionType,
15981599
kv_factor: int = 2,
15991600
kv_layout: Literal["HND", "NHD"] = "HND",
1600-
attention_type: AttentionType = "default",
16011601
) -> None:
16021602
"""Initialize the KVPagedResourceHandler.
16031603
@@ -1607,15 +1607,15 @@ def __init__(
16071607
dtype: The dtype of the KV cache.
16081608
kv_factor: The factor of the KV cache. Default is 2.
16091609
kv_layout: Memory layout - "HND" or "NHD". Default is "HND".
1610-
attention_type: Attention semantics for this cache resource.
1610+
attention_type: Attention layout semantics for this cache resource: ``"mha"`` or ``"mla"``.
16111611
"""
16121612
self.num_kv_heads = num_kv_heads
16131613
self.head_dim = head_dim
16141614
self.dtype = dtype
16151615
self.kv_factor = kv_factor
16161616
assert kv_factor in [1, 2], f"Invalid kv_factor: {kv_factor}"
16171617
self.kv_layout = kv_layout
1618-
assert attention_type in ["default", "mla"], f"Unsupported attention_type: {attention_type}"
1618+
assert attention_type in ["mha", "mla"], f"Unsupported attention_type: {attention_type}"
16191619
self.attention_type = attention_type
16201620

16211621
def __eq__(self, other: Optional[ResourceHandler]) -> bool:

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
from .interface import CachedSequenceInterface, GetInferenceModel
8181

8282
_ATTENTION_TYPE_TO_CPP = {
83-
"default": AttentionTypeCpp.DEFAULT,
83+
"mha": AttentionTypeCpp.DEFAULT,
8484
"mla": AttentionTypeCpp.MLA,
8585
}
8686

@@ -846,9 +846,10 @@ def _prepare_inputs(
846846
num_prefill_tokens = len(input_ids)
847847

848848
for request in gen_requests:
849-
# Overlap gathers tokens from the previous batch slot. Dummy padding
850-
# requests and first-step generation-only disagg requests do not have
851-
# a previous slot to gather from yet.
849+
# Overlap gathers tokens from the previous batch slot. Non-overlap
850+
# forwards do not pass new_tokens at all; first-step generation-only
851+
# disagg requests may have new_tokens from a previous batch but do
852+
# not have a previous batch slot (py_batch_idx) to gather from yet.
852853
is_overlap = (
853854
has_new_tokens
854855
and not self._disable_overlap_scheduler
@@ -1386,17 +1387,28 @@ def create_autodeploy_executor(
13861387
kv_cache_manager if isinstance(kv_cache_manager, BaseMambaCacheManager) else None
13871388
)
13881389
cache_transceiver_config = ad_config.cache_transceiver_config
1389-
if cache_transceiver_config and cache_transceiver_config.max_tokens_in_buffer is None:
1390+
cache_transceiver_enabled = (
1391+
cache_transceiver_config is not None and cache_transceiver_config.backend is not None
1392+
)
1393+
if cache_transceiver_enabled and cache_transceiver_config.max_tokens_in_buffer is None:
13901394
# The disagg transfer buffer must fit the full context segment handed to
13911395
# the generation worker. AutoDeploy's cache interface exposes the tuned
13921396
# maximum sequence length, which is a conservative upper bound.
13931397
cache_transceiver_config.max_tokens_in_buffer = engine.cache_seq_interface.info.max_seq_len
13941398

13951399
cache_attention_type = engine.cache_seq_interface.attention_type
1396-
try:
1397-
attention_type = _ATTENTION_TYPE_TO_CPP[cache_attention_type]
1398-
except KeyError as exc:
1399-
raise ValueError(f"Unsupported attention_type: {cache_attention_type!r}") from exc
1400+
if cache_transceiver_enabled and cache_attention_type is None:
1401+
raise RuntimeError(
1402+
"Cache transceiver is enabled, but AutoDeploy did not find a managed paged KV "
1403+
"resource to provide attention_type."
1404+
)
1405+
if cache_attention_type is None:
1406+
attention_type = AttentionTypeCpp.DEFAULT
1407+
else:
1408+
try:
1409+
attention_type = _ATTENTION_TYPE_TO_CPP[cache_attention_type]
1410+
except KeyError as exc:
1411+
raise ValueError(f"Unsupported attention_type: {cache_attention_type!r}") from exc
14001412

14011413
kv_cache_transceiver = create_kv_cache_transceiver(
14021414
dist_mapping,

tensorrt_llm/_torch/auto_deploy/shim/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict:
699699
"""
700700
# 1. Identify managed resources
701701
kv_ref, kv_managed = self._identify_managed_kv_resources()
702-
self.info.attention_type = kv_ref.attention_type if kv_ref is not None else "default"
702+
self.info.attention_type = kv_ref.attention_type if kv_ref is not None else None
703703
ssm_ref, ssm_managed, ssm_spec, conv_ref, conv_managed, conv_spec = (
704704
self._identify_managed_state_resources()
705705
)
@@ -939,7 +939,7 @@ def kv_cache_config(self) -> KvCacheConfig:
939939
return self._kv_cache_config_original
940940

941941
@property
942-
def attention_type(self) -> AttentionType:
942+
def attention_type(self) -> Optional[AttentionType]:
943943
return self.info.attention_type
944944

945945
def _clear_caches(self) -> None:

tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,3 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2-
# SPDX-License-Identifier: Apache-2.0
3-
#
4-
# Licensed under the Apache License, Version 2.0 (the "License");
5-
# you may not use this file except in compliance with the License.
6-
# You may obtain a copy of the License at
7-
#
8-
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
10-
# Unless required by applicable law or agreed to in writing, software
11-
# distributed under the License is distributed on an "AS IS" BASIS,
12-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
# See the License for the specific language governing permissions and
14-
# limitations under the License.
15-
161
import asyncio
172
import os
183
import pickle

tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
def test_paged_handler_with_nhd_layout():
3333
"""Test KVPagedResourceHandler with NHD layout."""
34-
handler = KVPagedResourceHandler(8, 64, dtype=torch.bfloat16, kv_layout="NHD")
34+
handler = KVPagedResourceHandler(
35+
8, 64, dtype=torch.bfloat16, kv_layout="NHD", attention_type="mha"
36+
)
3537
assert handler.num_kv_heads == 8
3638
assert handler.head_dim == 64
3739
assert handler.dtype == torch.bfloat16
@@ -40,7 +42,9 @@ def test_paged_handler_with_nhd_layout():
4042

4143
def test_paged_handler_with_hnd_layout():
4244
"""Test KVPagedResourceHandler with explicit HND layout."""
43-
handler = KVPagedResourceHandler(4, 128, dtype=torch.float32, kv_layout="HND")
45+
handler = KVPagedResourceHandler(
46+
4, 128, dtype=torch.float32, kv_layout="HND", attention_type="mha"
47+
)
4448
assert handler.num_kv_heads == 4
4549
assert handler.head_dim == 128
4650
assert handler.dtype == torch.float32
@@ -50,7 +54,9 @@ def test_paged_handler_with_hnd_layout():
5054
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
5155
def test_paged_handler_allocate_with_blocks(kv_layout):
5256
"""Verify KVPagedResourceHandler.allocate() returns correct shape."""
53-
handler = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_layout=kv_layout)
57+
handler = KVPagedResourceHandler(
58+
8, 64, dtype=torch.float16, kv_layout=kv_layout, attention_type="mha"
59+
)
5460
tokens_per_block = 32
5561
seq_info = SequenceInfo(
5662
max_seq_len=128,
@@ -88,7 +94,7 @@ def test_paged_handler_allocate_with_blocks(kv_layout):
8894

8995
def test_paged_handler_is_resource_handler():
9096
"""Verify KVPagedResourceHandler is a ResourceHandler subclass."""
91-
handler = KVPagedResourceHandler(8, 64, dtype=torch.float16)
97+
handler = KVPagedResourceHandler(8, 64, dtype=torch.float16, attention_type="mha")
9298
assert isinstance(handler, ResourceHandler)
9399

94100

@@ -271,9 +277,13 @@ def test_resolve_cache_dtype_explicit_fp8():
271277

272278
def test_kv_paged_handler_eq_same_head_dim_dtype():
273279
"""Verify KVPagedResourceHandler __eq__ checks head_dim and dtype."""
274-
h1 = KVPagedResourceHandler(8, 64, dtype=torch.float16)
275-
h2 = KVPagedResourceHandler(4, 64, dtype=torch.float16) # Different num_kv_heads
276-
h3 = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_layout="NHD") # Different layout
280+
h1 = KVPagedResourceHandler(8, 64, dtype=torch.float16, attention_type="mha")
281+
h2 = KVPagedResourceHandler(
282+
4, 64, dtype=torch.float16, attention_type="mha"
283+
) # Different num_kv_heads
284+
h3 = KVPagedResourceHandler(
285+
8, 64, dtype=torch.float16, kv_layout="NHD", attention_type="mha"
286+
) # Different layout
277287

278288
# Same head_dim, kv_factor, dtype, and kv_layout -> handlers compare equal (num_kv_heads doesn't matter for compatibility)
279289
assert h1 == h2
@@ -282,22 +292,28 @@ def test_kv_paged_handler_eq_same_head_dim_dtype():
282292

283293
def test_kv_paged_handler_eq_different_head_dim_or_dtype():
284294
"""Verify KVPagedResourceHandler __eq__ returns False for different head_dim or dtype."""
285-
h1 = KVPagedResourceHandler(8, 64, dtype=torch.float16)
286-
h2 = KVPagedResourceHandler(8, 128, dtype=torch.float16) # Different head_dim
287-
h3 = KVPagedResourceHandler(8, 64, dtype=torch.bfloat16) # Different dtype
295+
h1 = KVPagedResourceHandler(8, 64, dtype=torch.float16, attention_type="mha")
296+
h2 = KVPagedResourceHandler(
297+
8, 128, dtype=torch.float16, attention_type="mha"
298+
) # Different head_dim
299+
h3 = KVPagedResourceHandler(
300+
8, 64, dtype=torch.bfloat16, attention_type="mha"
301+
) # Different dtype
288302

289303
assert h1 != h2
290304
assert h1 != h3
291305

292306

293307
def test_kv_paged_handler_eq_different_attention_type():
294308
"""Verify KVPagedResourceHandler __eq__ rejects different attention semantics."""
295-
default_handler = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_factor=1)
309+
default_handler = KVPagedResourceHandler(
310+
8, 64, dtype=torch.float16, kv_factor=1, attention_type="mha"
311+
)
296312
mla_handler = KVPagedResourceHandler(
297313
8, 64, dtype=torch.float16, kv_factor=1, attention_type="mla"
298314
)
299315

300-
assert default_handler.attention_type == "default"
316+
assert default_handler.attention_type == "mha"
301317
assert mla_handler.attention_type == "mla"
302318
assert default_handler != mla_handler
303319

0 commit comments

Comments
 (0)