Skip to content

Commit 2cebbef

Browse files
committed
Clean up code and add tests
Signed-off-by: Bo Deng <deemod@nvidia.com>
1 parent 2b57dd9 commit 2cebbef

5 files changed

Lines changed: 100 additions & 9 deletions

File tree

tensorrt_llm/_torch/disaggregation/native/peer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def _check_peer_compatible(self, peer_ri: RankInfo) -> bool:
100100
# Allow mismatch when one side has speculative (e.g. MTP) layers
101101
# that the other side doesn't. The pool_mapping logic will only
102102
# transfer layers that exist on both sides.
103-
logger.info(
103+
logger.warning(
104104
"PeerRegistrar: layer count differs "
105105
f"(local={self_layers}, peer={peer_layers}), "
106106
"allowing partial layer transfer."

tensorrt_llm/_torch/disaggregation/native/transfer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,11 +1717,6 @@ def unpack_aux(self, request: LlmRequest) -> None:
17171717
self._aux_buffer.get_slot_data(self.aux_slot)
17181718
)
17191719
request.py_first_gen_tokens = first_gen_tokens # type: ignore[attr-defined]
1720-
# When CTX has no MTP but GEN does, draft_tokens will be empty.
1721-
# Pad with dummy zeros so GEN's MTP forward sees uniform draft_len
1722-
# across the batch. The dummy tokens will be rejected on first verify.
1723-
if not draft_tokens and self._aux_buffer._max_draft_len > 0:
1724-
draft_tokens = [0] * self._aux_buffer._max_draft_len
17251720
request.py_draft_tokens = draft_tokens # type: ignore[attr-defined]
17261721
if request.py_disaggregated_params is not None:
17271722
request.py_disaggregated_params.ctx_usage = {

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3742,9 +3742,13 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
37423742
first_gen_tokens = req.context_phase_params.first_gen_tokens
37433743
ctx_draft_tokens = req.context_phase_params.draft_tokens
37443744
if not ctx_draft_tokens and self.model_engine.enable_spec_decode:
3745-
# CTX has no MTP — fill dummy draft tokens so GEN's MTP
3746-
# forward sees uniform draft_len. Dummies will be rejected.
3747-
ctx_draft_tokens = [0] * self.model_engine.max_draft_len
3745+
# CTX has no speculative decoding — fill dummy draft tokens
3746+
# so model_engine builds the correct input shape (1 + draft_len
3747+
# tokens per gen request). Dummies will be rejected on verify,
3748+
# and the draft model will produce real tokens for next step.
3749+
ctx_draft_tokens = [
3750+
0
3751+
] * self.model_engine.max_total_draft_tokens
37483752
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
37493753
beam_width = req.py_beam_width
37503754
for beam in range(0, beam_width):

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,63 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
780780
self.MODEL_PATH) as llm:
781781
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
782782

783+
@pytest.mark.skip_less_device(2)
784+
@skip_pre_hopper
785+
def test_gen_only_spec_dec(self):
786+
speculative_decoding_config = {
787+
"decoding_type": "Eagle",
788+
"max_draft_len": 4,
789+
"speculative_model":
790+
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
791+
"eagle3_one_model": True,
792+
}
793+
ctx_server_config = {
794+
"disable_overlap_scheduler":
795+
True, # BS=1 does not need overlap scheduling
796+
"kv_cache_config": {
797+
"free_gpu_memory_fraction": 0.5,
798+
"enable_block_reuse": True # reuse on context requests
799+
},
800+
"max_num_tokens": 13393 * 2,
801+
"max_batch_size": 1,
802+
"cache_transceiver_config": {
803+
"backend": "NIXL",
804+
"transceiver_runtime": "PYTHON",
805+
"max_tokens_in_buffer": 4096,
806+
},
807+
"cuda_graph_config": None,
808+
}
809+
gen_server_config = {
810+
"disable_overlap_scheduler": False,
811+
"speculative_config": speculative_decoding_config,
812+
"kv_cache_config": {
813+
"free_gpu_memory_fraction": 0.5,
814+
"enable_block_reuse": False
815+
},
816+
"max_num_tokens": 13393 * 2,
817+
"max_batch_size": 16,
818+
"cache_transceiver_config": {
819+
"backend": "NIXL",
820+
"transceiver_runtime": "PYTHON",
821+
"max_tokens_in_buffer": 4096,
822+
},
823+
"cuda_graph_config": None,
824+
}
825+
disaggregated_server_config = {
826+
"hostname": "localhost",
827+
"backend": "pytorch",
828+
"context_servers": {
829+
"num_instances": 1
830+
},
831+
"generation_servers": {
832+
"num_instances": 1
833+
}
834+
}
835+
with launch_disaggregated_llm(disaggregated_server_config,
836+
ctx_server_config, gen_server_config,
837+
self.MODEL_PATH) as llm:
838+
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
839+
783840
@pytest.mark.skip_less_device(2)
784841
@pytest.mark.skip_less_device_memory(32000)
785842
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
@@ -1041,6 +1098,39 @@ def test_gen_only_sync(self):
10411098
) as llm:
10421099
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
10431100

1101+
@pytest.mark.skip_less_device(8)
1102+
@skip_pre_hopper
1103+
def test_gen_only_spec_dec(self):
1104+
ctx_server_config = {"disable_overlap_scheduler": True}
1105+
gen_server_config = {"disable_overlap_scheduler": False}
1106+
cache_transceiver_config = {
1107+
"backend": "NIXL",
1108+
"max_tokens_in_buffer": 4096,
1109+
"transceiver_runtime": "PYTHON",
1110+
}
1111+
ctx_server_config["cache_transceiver_config"] = cache_transceiver_config
1112+
gen_server_config["cache_transceiver_config"] = cache_transceiver_config
1113+
gen_server_config["speculative_config"] = {
1114+
"decoding_type": "MTP",
1115+
"max_draft_len": 2
1116+
}
1117+
disaggregated_server_config = {
1118+
"hostname": "localhost",
1119+
"backend": "pytorch",
1120+
"context_servers": {
1121+
"num_instances": 1
1122+
},
1123+
"generation_servers": {
1124+
"num_instances": 1
1125+
}
1126+
}
1127+
with launch_disaggregated_llm(disaggregated_server_config,
1128+
ctx_server_config,
1129+
gen_server_config,
1130+
self.MODEL_PATH,
1131+
tensor_parallel_size=4) as llm:
1132+
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
1133+
10441134
@pytest.mark.skip_less_device(8)
10451135
@parametrize_with_ids("overlap_scheduler", [True, False])
10461136
@parametrize_with_ids("mtp_nextn", [0, 2])

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ l0_dgx_h100:
5858
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
5959
# llmapi
6060
- unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
61+
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_gen_only_spec_dec
62+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_spec_dec
6163
- condition:
6264
ranges:
6365
system_gpu_count:

0 commit comments

Comments
 (0)