Skip to content

Commit 228829c

Browse files
authored
[TRTLLM-12958][feat] Enable gen-only spec dec (#14546)
Signed-off-by: Bo Deng <deemod@nvidia.com>
1 parent 5e3f012 commit 228829c

6 files changed

Lines changed: 112 additions & 4 deletions

File tree

tensorrt_llm/_torch/disaggregation/native/peer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,14 @@ def _check_peer_compatible(self, peer_ri: RankInfo) -> bool:
9797
self_layers = sum(self._ri.layer_num_per_pp)
9898
peer_layers = sum(peer_ri.layer_num_per_pp)
9999
if self_layers != peer_layers:
100+
# Allow mismatch when one side has speculative (e.g. MTP) layers
101+
# that the other side doesn't. The pool_mapping logic will only
102+
# transfer layers that exist on both sides.
100103
logger.warning(
101-
"PeerRegistrar: total layer count mismatch "
102-
f"(local={self_layers}, peer={peer_layers})."
104+
"PeerRegistrar: layer count differs "
105+
f"(local={self_layers}, peer={peer_layers}), "
106+
"allowing partial layer transfer."
103107
)
104-
return False
105108

106109
return True
107110

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,16 @@ def _create_kv_slice(
178178
groups.append(np.array([], dtype=np.int64))
179179
continue
180180
block_ids = adapter.get_block_ids(req, idx, lg)
181+
# Limit to prompt_len blocks, matching C++ cacheFormatter behavior.
182+
# Extra blocks from num_extra_kv_tokens (speculative decoding) have
183+
# uninitialized KV data and must not be transferred.
184+
total_blocks = (req.prompt_len + tpb - 1) // tpb
185+
if block_ids.size > total_blocks:
186+
block_ids = block_ids[:total_blocks]
181187
window_size = lg.sliding_window_size
182188

183189
if window_size is not None:
184190
# Drop stale blocks the manager may still expose (V1 pre-eviction).
185-
total_blocks = (req.prompt_len + tpb - 1) // tpb
186191
stale_end = max(0, (req.prompt_len + 1 - window_size) // tpb)
187192
expected_valid = max(0, total_blocks - stale_end)
188193
if block_ids.size > expected_valid:

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4223,6 +4223,14 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
42234223
req.py_kv_transfer_timed_out = False
42244224
first_gen_tokens = req.context_phase_params.first_gen_tokens
42254225
ctx_draft_tokens = req.context_phase_params.draft_tokens
4226+
if not ctx_draft_tokens and self.model_engine.enable_spec_decode:
4227+
# CTX has no speculative decoding — fill dummy draft tokens
4228+
# so model_engine builds the correct input shape (1 + draft_len
4229+
# tokens per gen request). Dummies will be rejected on verify,
4230+
# and the draft model will produce real tokens for next step.
4231+
ctx_draft_tokens = [
4232+
0
4233+
] * self.model_engine.max_total_draft_tokens
42264234
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
42274235
beam_width = req.py_beam_width
42284236
for beam in range(0, beam_width):

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,63 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
780780
self.MODEL_PATH) as llm:
781781
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
782782

783+
@pytest.mark.skip_less_device(2)
784+
@skip_pre_hopper
785+
def test_gen_only_spec_dec(self):
786+
speculative_decoding_config = {
787+
"decoding_type": "Eagle",
788+
"max_draft_len": 4,
789+
"speculative_model":
790+
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
791+
"eagle3_one_model": True,
792+
}
793+
ctx_server_config = {
794+
"disable_overlap_scheduler":
795+
True, # BS=1 does not need overlap scheduling
796+
"kv_cache_config": {
797+
"free_gpu_memory_fraction": 0.5,
798+
"enable_block_reuse": True # reuse on context requests
799+
},
800+
"max_num_tokens": 13393 * 2,
801+
"max_batch_size": 1,
802+
"cache_transceiver_config": {
803+
"backend": "NIXL",
804+
"transceiver_runtime": "PYTHON",
805+
"max_tokens_in_buffer": 4096,
806+
},
807+
"cuda_graph_config": None,
808+
}
809+
gen_server_config = {
810+
"disable_overlap_scheduler": False,
811+
"speculative_config": speculative_decoding_config,
812+
"kv_cache_config": {
813+
"free_gpu_memory_fraction": 0.5,
814+
"enable_block_reuse": False
815+
},
816+
"max_num_tokens": 13393 * 2,
817+
"max_batch_size": 16,
818+
"cache_transceiver_config": {
819+
"backend": "NIXL",
820+
"transceiver_runtime": "PYTHON",
821+
"max_tokens_in_buffer": 4096,
822+
},
823+
"cuda_graph_config": None,
824+
}
825+
disaggregated_server_config = {
826+
"hostname": "localhost",
827+
"backend": "pytorch",
828+
"context_servers": {
829+
"num_instances": 1
830+
},
831+
"generation_servers": {
832+
"num_instances": 1
833+
}
834+
}
835+
with launch_disaggregated_llm(disaggregated_server_config,
836+
ctx_server_config, gen_server_config,
837+
self.MODEL_PATH) as llm:
838+
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
839+
783840
@pytest.mark.skip_less_device(2)
784841
@pytest.mark.skip_less_device_memory(32000)
785842
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
@@ -1001,6 +1058,39 @@ def test_gen_only_sync(self):
10011058
) as llm:
10021059
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
10031060

1061+
@pytest.mark.skip_less_device(8)
1062+
@skip_pre_hopper
1063+
def test_gen_only_spec_dec(self):
1064+
ctx_server_config = {"disable_overlap_scheduler": True}
1065+
gen_server_config = {"disable_overlap_scheduler": False}
1066+
cache_transceiver_config = {
1067+
"backend": "NIXL",
1068+
"max_tokens_in_buffer": 4096,
1069+
"transceiver_runtime": "PYTHON",
1070+
}
1071+
ctx_server_config["cache_transceiver_config"] = cache_transceiver_config
1072+
gen_server_config["cache_transceiver_config"] = cache_transceiver_config
1073+
gen_server_config["speculative_config"] = {
1074+
"decoding_type": "MTP",
1075+
"max_draft_len": 2
1076+
}
1077+
disaggregated_server_config = {
1078+
"hostname": "localhost",
1079+
"backend": "pytorch",
1080+
"context_servers": {
1081+
"num_instances": 1
1082+
},
1083+
"generation_servers": {
1084+
"num_instances": 1
1085+
}
1086+
}
1087+
with launch_disaggregated_llm(disaggregated_server_config,
1088+
ctx_server_config,
1089+
gen_server_config,
1090+
self.MODEL_PATH,
1091+
tensor_parallel_size=4) as llm:
1092+
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
1093+
10041094
@pytest.mark.skip_less_device(8)
10051095
@parametrize_with_ids("overlap_scheduler", [True, False])
10061096
@parametrize_with_ids("mtp_nextn", [0, 2])

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ l0_dgx_h100:
5656
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
5757
# llmapi
5858
- unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
59+
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_gen_only_spec_dec
5960
- condition:
6061
ranges:
6162
system_gpu_count:

tests/integration/test_lists/test-db/l0_dgx_h200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ l0_dgx_h200:
4141
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_genpp4[TinyLlama-1.1B-Chat-v1.0]
4242
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxtp2ep2pp2_gentp4_one_mtp_block_reuse[DeepSeek-V3-Lite-fp8]
4343
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora
44+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_spec_dec
4445
- condition:
4546
ranges:
4647
system_gpu_count:

0 commit comments

Comments
 (0)