Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ speculative_config:
num_nextn_predict_layers: 6
mtp_eagle_one_model: true
transforms:
compile_model:
# MTP speculative decoding does not support piecewise CUDA graph capture yet.
piecewise_enabled: false
detect_sharding:
allreduce_strategy: NCCL
# NOTE: add 'tp' to sharding dims only for high-throughput runs
Expand Down
14 changes: 14 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,20 @@ def reject_cudagraph_for_speculative_flashinfer(self):
)
return self

@model_validator(mode="after")
def reject_piecewise_cuda_graph_for_speculative_decoding(self):
    """Reject configs that combine speculative decoding with piecewise CUDA graphs.

    The configuration is refused only when all three conditions hold:
    a speculative config is set, ``is_cuda_graph_enabled()`` reports true, and
    the ``compile_model`` transform entry has a truthy ``piecewise_enabled``.

    Returns:
        ``self``, unchanged, when the combination is acceptable.

    Raises:
        ValueError: If piecewise CUDA graph capture is requested together with
            speculative decoding.
    """
    # A missing "compile_model" entry is treated as an empty mapping, i.e.
    # piecewise capture is considered not requested.
    # NOTE(review): assumes the entry, when present, is a mapping - confirm.
    compile_cfg = self.transforms.get("compile_model", {})

    # Guard clauses, checked in the same order as the combined condition.
    if self.speculative_config is None:
        return self
    if not self.is_cuda_graph_enabled():
        return self
    if not compile_cfg.get("piecewise_enabled", False):
        return self

    raise ValueError(
        "Speculative decoding with AutoDeploy does not currently support piecewise CUDA "
        "graph capture."
    )

@model_validator(mode="after")
def disable_piecewise_for_non_piecewise_backend(self):
compile_model = self.transforms.get("compile_model")
Expand Down
20 changes: 19 additions & 1 deletion tests/integration/defs/disaggregated/test_ad_disagg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import cloudpickle
import pytest
import torch
from defs.conftest import get_sm_version, skip_pre_hopper
from defs.conftest import check_device_contain, get_sm_version, skip_pre_hopper
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor

Expand All @@ -41,6 +41,16 @@
pickle.HIGHEST_PROTOCOL,
)


@pytest.fixture(autouse=True)
def skip_b300():
    """Autouse fixture that skips every test in this module on B300 devices.

    The skip is temporary; the reason string carries the tracking bug link.
    """
    # NOTE(review): the message mentions GB300 as well; presumably
    # check_device_contain matches by substring so "B300" covers it - confirm.
    running_on_b300 = check_device_contain(["B300"])
    if not running_on_b300:
        return
    pytest.skip(
        "AutoDeploy disagg tests are disabled on B300/GB300 until capacity is available: "
        "https://nvbugs/6301621"
    )


WORKER_READY = "ready"
REQUEST_MODE_AGGREGATE = "aggregate"
MPI_REQUEST = 9999
Expand Down Expand Up @@ -143,6 +153,12 @@ def seed_disagg():
torch.cuda.manual_seed_all(AUTODEPLOY_DISAGG_SEED)


def disable_piecewise_cuda_graph_for_speculation(config: dict) -> dict:
    """Force ``transforms.compile_model.piecewise_enabled`` to ``False`` in place.

    Missing ``transforms`` / ``compile_model`` levels are created as empty
    dicts; any other keys already present at either level are left untouched.

    Args:
        config: AutoDeploy config dictionary; it is mutated.

    Returns:
        The same ``config`` object, to allow call chaining.
    """
    transforms_section = config.setdefault("transforms", {})
    compile_section = transforms_section.setdefault("compile_model", {})
    compile_section["piecewise_enabled"] = False
    return config


def base_config(extra_config=None):
common_config = dict(
runtime="trtllm",
Expand All @@ -157,6 +173,8 @@ def base_config(extra_config=None):
)
if extra_config:
common_config.update(extra_config)
if common_config.get("speculative_config") is not None:
disable_piecewise_cuda_graph_for_speculation(common_config)

return common_config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
import requests
from defs.common import get_free_port_in_ci as get_free_port
from defs.conftest import llm_models_root
from defs.conftest import check_device_contain, llm_models_root
from disagg_test_utils import (
CHECK_STATUS_INTERVAL,
HEARTBEAT_INTERVAL,
Expand All @@ -34,6 +34,16 @@

pytest_plugins = ["disagg_test_utils"]


@pytest.fixture(autouse=True)
def skip_b300():
    """Autouse fixture that skips every test in this module on B300 devices.

    The skip is temporary; the reason string carries the tracking bug link.
    """
    # NOTE(review): the message mentions GB300 as well; presumably
    # check_device_contain matches by substring so "B300" covers it - confirm.
    running_on_b300 = check_device_contain(["B300"])
    if not running_on_b300:
        return
    pytest.skip(
        "AutoDeploy disagg tests are disabled on B300/GB300 until capacity is available: "
        "https://nvbugs/6301621"
    )


SERVER_START_TIMEOUT_S = 300
SERVER_READY_REQUEST_TIMEOUT_S = 5
OPENAI_REQUEST_TIMEOUT_S = 60
Expand Down
13 changes: 0 additions & 13 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,6 @@ cpp/test_multi_gpu.py::TestDisagg::test_symmetric_executor[gpt-2proc-mpi_kvcache
cpp/test_multi_gpu.py::TestDisagg::test_symmetric_executor[gpt-2proc-nixl_kvcache-90] SKIP (https://nvbugs/6093820)
cpp/test_multi_gpu.py::TestDisagg::test_symmetric_executor[gpt-2proc-ucx_kvcache-90] SKIP (https://nvbugs/6093820)
cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] SKIP (https://nvbugs/5838199)
disaggregated/test_ad_disagg.py::test_async_eagle3_full_model_handoff SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_async_generation_matches_aggregate SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_async_generation_no_overlap_matches_aggregate SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_async_sharded_generation_handoff SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_chunked_prefill_handoff[deepseek_v3_mla] SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_chunked_prefill_handoff[tinyllama] SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_disaggregated_logits[deepseek_v3_mla] SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_disaggregated_logits[tinyllama] SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_reduced_layer_handoff_matches_aggregate[deepseek_v3_mla] SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_reduced_layer_handoff_matches_aggregate[tinyllama] SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg.py::test_tinyllama_batch_handoff_semantic_slots SKIP (https://nvbugs/6306936)
disaggregated/test_ad_disagg_trtllm_serve.py::test_openai_completion SKIP (https://nvbugs/6306936)
disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/6105768)
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_cache_aware_balance[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/6162322)
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/6162322)
Expand Down Expand Up @@ -445,7 +433,6 @@ unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_t
unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingDSv3-swiglu-1024-1024-1] SKIP (https://nvbugs/5908070)
unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingRenormalize_qwen_next-swiglu-1024-1024-150] SKIP (https://nvbugs/5908070)
unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingRenormalize_topk_4-swiglu-1024-1024-150] SKIP (https://nvbugs/5908070)
unittest/auto_deploy/singlegpu/smoke SKIP (https://nvbugs/6306936)
unittest/bindings/test_transfer_agent_bindings.py::TestNixlFunctionalTransfer::test_nixl_wait_in_progress_on_zero_timeout SKIP (https://nvbugs/6260897)
unittest/executor/test_rpc.py::TestRpcCorrectness::test_incremental_task_async SKIP (https://nvbugs/5741476)
unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
"free_gpu_memory_fraction": 0.0, # No resizing of the cache to keep the mem footprint small
}
llm_args["max_batch_size"] = 2 # Minimum batching to speed up things
llm_args["max_seq_len"] = 256
llm_args["cuda_graph_config"] = {"max_batch_size": 2} # Match max_batch_size
# update with custom llm_args kwargs
llm_args.update(llm_args_kwargs)
Expand Down
21 changes: 18 additions & 3 deletions tests/unittest/auto_deploy/singlegpu/shim/test_llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ class TestSpeculativeConfigValidation:
Verify that supported speculative modes are accepted and configured before executor setup.
"""

@staticmethod
def piecewise_disabled_transforms():
return {"compile_model": {"piecewise_enabled": False}}

def test_accepts_eagle_one_model(self):
from tensorrt_llm.llmapi import EagleDecodingConfig

Expand All @@ -284,7 +288,11 @@ def test_accepts_eagle_one_model(self):
eagle3_one_model=True,
)
# Should not raise.
args = LlmArgs(model="test-model", speculative_config=spec_config)
args = LlmArgs(
model="test-model",
speculative_config=spec_config,
transforms=self.piecewise_disabled_transforms(),
)
assert args.model_factory == "eagle_one_model"

def test_accepts_mtp_eagle_one_model(self):
Expand All @@ -295,7 +303,11 @@ def test_accepts_mtp_eagle_one_model(self):
mtp_eagle_one_model=True,
)
# Should not raise.
args = LlmArgs(model="test-model", speculative_config=spec_config)
args = LlmArgs(
model="test-model",
speculative_config=spec_config,
transforms=self.piecewise_disabled_transforms(),
)
assert args.model_factory == "eagle_one_model"

@pytest.mark.parametrize("compile_backend", ["torch-cudagraph", "torch-opt"])
Expand Down Expand Up @@ -356,7 +368,10 @@ def test_ssm_replay_with_spec_ok(self):
args = LlmArgs(
model="test-model",
speculative_config=spec_config,
transforms={"insert_cached_ssm_attention": {"ssm_replay": True}},
transforms={
"compile_model": {"piecewise_enabled": False},
"insert_cached_ssm_attention": {"ssm_replay": True},
},
)
assert args.transforms["insert_cached_ssm_attention"]["ssm_replay"] is True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def get_extra_seq_len_for_kv_cache(llm_args) -> int:
return extra


def piecewise_disabled_transforms():
    """Build a ``transforms`` override with piecewise capture off in ``compile_model``.

    A new dictionary is created on every call, so callers can mutate the
    result without affecting other tests.
    """
    compile_model_overrides = {"piecewise_enabled": False}
    return {"compile_model": compile_model_overrides}


def test_super_mtp_smoke():
"""Test one-model MTP/Eagle runtime with a tiny Nemotron SuperV3 target."""
test_prompt = "What is the capital of France?"
Expand Down Expand Up @@ -190,6 +194,7 @@ def test_kv_cache_extra_seq_len_for_spec_dec():
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
speculative_config=spec_config,
disable_overlap_scheduler=True,
transforms=piecewise_disabled_transforms(),
)
extra = get_extra_seq_len_for_kv_cache(args_eagle)
# Should include max_total_draft_tokens + get_num_extra_kv_tokens (max_draft_len - 1)
Expand All @@ -201,6 +206,7 @@ def test_kv_cache_extra_seq_len_for_spec_dec():
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
speculative_config=spec_config,
disable_overlap_scheduler=False,
transforms=piecewise_disabled_transforms(),
)
extra_overlap = get_extra_seq_len_for_kv_cache(args_eagle_overlap)
# Should be more than without overlap
Expand All @@ -217,6 +223,7 @@ def test_mtp_autodeploy_uses_eagle_one_model_capture():
num_nextn_predict_layers=3,
mtp_eagle_one_model=True,
),
transforms=piecewise_disabled_transforms(),
)

assert isinstance(args.speculative_config, MTPDecodingConfig)
Expand All @@ -229,6 +236,9 @@ def test_detect_hidden_states_capture_last_layer_for_mtp_eagle_one_model():
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs

config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct")
config["args"].setdefault("transforms", {}).setdefault("compile_model", {})[
"piecewise_enabled"
] = False

args = LlmArgs(
**config["args"],
Expand Down
Loading