Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/disaggregation/native/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,14 @@ def _check_peer_compatible(self, peer_ri: RankInfo) -> bool:
self_layers = sum(self._ri.layer_num_per_pp)
peer_layers = sum(peer_ri.layer_num_per_pp)
if self_layers != peer_layers:
# Allow mismatch when one side has speculative (e.g. MTP) layers
# that the other side doesn't. The pool_mapping logic will only
# transfer layers that exist on both sides.
logger.warning(
"PeerRegistrar: total layer count mismatch "
f"(local={self_layers}, peer={peer_layers})."
"PeerRegistrar: layer count differs "
f"(local={self_layers}, peer={peer_layers}), "
"allowing partial layer transfer."
)
Comment thread
bo-nv marked this conversation as resolved.
return False

return True

Expand Down
7 changes: 6 additions & 1 deletion tensorrt_llm/_torch/disaggregation/transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,16 @@ def _create_kv_slice(
groups.append(np.array([], dtype=np.int64))
continue
block_ids = adapter.get_block_ids(req, idx, lg)
# Limit to prompt_len blocks, matching C++ cacheFormatter behavior.
# Extra blocks from num_extra_kv_tokens (speculative decoding) have
# uninitialized KV data and must not be transferred.
total_blocks = (req.prompt_len + tpb - 1) // tpb

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, along with token_range, the requirement is to only transfer blocks for prompt_len. However, in practice, prompt + num_extra_kv_tokens blocks are allocated.

If MTP is enabled for both the context phase and the generation phase, then the current modification will only transfer prompt_len
blocks, and the extra block that may will not be transferred. The questions are:

  1. Will the KV cache for num_extra_kv_tokens be written to during the context phase?
  2. Will the KV cache written during the context phase be used by the generation phase?
  3. When both context and generation have MTP enabled, do we need to transfer prompt_len + num_extra_kv_tokens KV blocks?
    cc @lfr-0531

@bo-nv bo-nv Jun 10, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chuangz0 these changes fix the accuracy issue with py-transceiver + eagle3. Without them,accuracy drops even if enable eagle3 for both ctx and gen.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge first

if block_ids.size > total_blocks:
block_ids = block_ids[:total_blocks]
window_size = lg.sliding_window_size

if window_size is not None:
# Drop stale blocks the manager may still expose (V1 pre-eviction).
total_blocks = (req.prompt_len + tpb - 1) // tpb
stale_end = max(0, (req.prompt_len + 1 - window_size) // tpb)
expected_valid = max(0, total_blocks - stale_end)
if block_ids.size > expected_valid:
Expand Down
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4214,6 +4214,14 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
req.py_kv_transfer_timed_out = False
first_gen_tokens = req.context_phase_params.first_gen_tokens
ctx_draft_tokens = req.context_phase_params.draft_tokens
if not ctx_draft_tokens and self.model_engine.enable_spec_decode:
Comment thread
Shixiaowei02 marked this conversation as resolved.
# CTX has no speculative decoding — fill dummy draft tokens
# so model_engine builds the correct input shape (1 + draft_len
# tokens per gen request). Dummies will be rejected on verify,
# and the draft model will produce real tokens for next step.
ctx_draft_tokens = [
0
] * self.model_engine.max_total_draft_tokens
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
beam_width = req.py_beam_width
for beam in range(0, beam_width):
Expand Down
90 changes: 90 additions & 0 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,63 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

@pytest.mark.skip_less_device(2)
@skip_pre_hopper
def test_gen_only_spec_dec(self):
speculative_decoding_config = {
"decoding_type": "Eagle",
"max_draft_len": 4,
"speculative_model":
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
"eagle3_one_model": True,
}
ctx_server_config = {
"disable_overlap_scheduler":
True, # BS=1 does not need overlap scheduling
"kv_cache_config": {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": True # reuse on context requests
},
"max_num_tokens": 13393 * 2,
"max_batch_size": 1,
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON",
"max_tokens_in_buffer": 4096,
},
"cuda_graph_config": None,
}
gen_server_config = {
"disable_overlap_scheduler": False,
"speculative_config": speculative_decoding_config,
"kv_cache_config": {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False
},
"max_num_tokens": 13393 * 2,
"max_batch_size": 16,
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON",
"max_tokens_in_buffer": 4096,
},
"cuda_graph_config": None,
}
disaggregated_server_config = {
"hostname": "localhost",
"backend": "pytorch",
"context_servers": {
"num_instances": 1
},
"generation_servers": {
"num_instances": 1
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
Expand Down Expand Up @@ -1001,6 +1058,39 @@ def test_gen_only_sync(self):
) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

@pytest.mark.skip_less_device(8)
Comment thread
bo-nv marked this conversation as resolved.
@skip_pre_hopper
def test_gen_only_spec_dec(self):
ctx_server_config = {"disable_overlap_scheduler": True}
gen_server_config = {"disable_overlap_scheduler": False}
cache_transceiver_config = {
"backend": "NIXL",
"max_tokens_in_buffer": 4096,
"transceiver_runtime": "PYTHON",
}
ctx_server_config["cache_transceiver_config"] = cache_transceiver_config
gen_server_config["cache_transceiver_config"] = cache_transceiver_config
gen_server_config["speculative_config"] = {
"decoding_type": "MTP",
"max_draft_len": 2
}
disaggregated_server_config = {
"hostname": "localhost",
"backend": "pytorch",
"context_servers": {
"num_instances": 1
},
"generation_servers": {
"num_instances": 1
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config,
gen_server_config,
self.MODEL_PATH,
tensor_parallel_size=4) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

@pytest.mark.skip_less_device(8)
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("mtp_nextn", [0, 2])
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ l0_dgx_h100:
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
# llmapi
- unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_gen_only_spec_dec
- condition:
ranges:
system_gpu_count:
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h200.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ l0_dgx_h200:
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_genpp4[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxtp2ep2pp2_gentp4_one_mtp_block_reuse[DeepSeek-V3-Lite-fp8]
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_spec_dec
- condition:
ranges:
system_gpu_count:
Expand Down
Loading