Skip to content

Commit e419a17

Browse files
[None][test] Reject AD piecewise cudagraph with speculation
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 831ca36 commit e419a17

3 files changed

Lines changed: 42 additions & 3 deletions

File tree

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,20 @@ def reject_cudagraph_for_speculative_flashinfer(self):
443443
)
444444
return self
445445

446+
@model_validator(mode="after")
447+
def reject_piecewise_cuda_graph_for_speculative_decoding(self):
448+
compile_model = self.transforms.get("compile_model", {})
449+
if (
450+
self.speculative_config is not None
451+
and self.is_cuda_graph_enabled()
452+
and compile_model.get("piecewise_enabled", False)
453+
):
454+
raise ValueError(
455+
"Speculative decoding with AutoDeploy does not currently support piecewise CUDA "
456+
"graph capture."
457+
)
458+
return self
459+
446460
@model_validator(mode="after")
447461
def disable_piecewise_for_non_piecewise_backend(self):
448462
compile_model = self.transforms.get("compile_model")

tests/unittest/auto_deploy/singlegpu/shim/test_llm_config.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ class TestSpeculativeConfigValidation:
275275
Verify that supported speculative modes are accepted and configured before executor setup.
276276
"""
277277

278+
@staticmethod
279+
def piecewise_disabled_transforms():
280+
return {"compile_model": {"piecewise_enabled": False}}
281+
278282
def test_accepts_eagle_one_model(self):
279283
from tensorrt_llm.llmapi import EagleDecodingConfig
280284

@@ -284,7 +288,11 @@ def test_accepts_eagle_one_model(self):
284288
eagle3_one_model=True,
285289
)
286290
# Should not raise.
287-
args = LlmArgs(model="test-model", speculative_config=spec_config)
291+
args = LlmArgs(
292+
model="test-model",
293+
speculative_config=spec_config,
294+
transforms=self.piecewise_disabled_transforms(),
295+
)
288296
assert args.model_factory == "eagle_one_model"
289297

290298
def test_accepts_mtp_eagle_one_model(self):
@@ -295,7 +303,11 @@ def test_accepts_mtp_eagle_one_model(self):
295303
mtp_eagle_one_model=True,
296304
)
297305
# Should not raise.
298-
args = LlmArgs(model="test-model", speculative_config=spec_config)
306+
args = LlmArgs(
307+
model="test-model",
308+
speculative_config=spec_config,
309+
transforms=self.piecewise_disabled_transforms(),
310+
)
299311
assert args.model_factory == "eagle_one_model"
300312

301313
@pytest.mark.parametrize("compile_backend", ["torch-cudagraph", "torch-opt"])
@@ -356,7 +368,10 @@ def test_ssm_replay_with_spec_ok(self):
356368
args = LlmArgs(
357369
model="test-model",
358370
speculative_config=spec_config,
359-
transforms={"insert_cached_ssm_attention": {"ssm_replay": True}},
371+
transforms={
372+
"compile_model": {"piecewise_enabled": False},
373+
"insert_cached_ssm_attention": {"ssm_replay": True},
374+
},
360375
)
361376
assert args.transforms["insert_cached_ssm_attention"]["ssm_replay"] is True
362377

tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def get_extra_seq_len_for_kv_cache(llm_args) -> int:
4545
return extra
4646

4747

48+
def piecewise_disabled_transforms():
49+
return {"compile_model": {"piecewise_enabled": False}}
50+
51+
4852
def test_super_mtp_smoke():
4953
"""Test one-model MTP/Eagle runtime with a tiny Nemotron SuperV3 target."""
5054
test_prompt = "What is the capital of France?"
@@ -190,6 +194,7 @@ def test_kv_cache_extra_seq_len_for_spec_dec():
190194
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
191195
speculative_config=spec_config,
192196
disable_overlap_scheduler=True,
197+
transforms=piecewise_disabled_transforms(),
193198
)
194199
extra = get_extra_seq_len_for_kv_cache(args_eagle)
195200
# Should include max_total_draft_tokens + get_num_extra_kv_tokens (max_draft_len - 1)
@@ -201,6 +206,7 @@ def test_kv_cache_extra_seq_len_for_spec_dec():
201206
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
202207
speculative_config=spec_config,
203208
disable_overlap_scheduler=False,
209+
transforms=piecewise_disabled_transforms(),
204210
)
205211
extra_overlap = get_extra_seq_len_for_kv_cache(args_eagle_overlap)
206212
# Should be more than without overlap
@@ -217,6 +223,7 @@ def test_mtp_autodeploy_uses_eagle_one_model_capture():
217223
num_nextn_predict_layers=3,
218224
mtp_eagle_one_model=True,
219225
),
226+
transforms=piecewise_disabled_transforms(),
220227
)
221228

222229
assert isinstance(args.speculative_config, MTPDecodingConfig)
@@ -229,6 +236,9 @@ def test_detect_hidden_states_capture_last_layer_for_mtp_eagle_one_model():
229236
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
230237

231238
config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct")
239+
config["args"].setdefault("transforms", {}).setdefault("compile_model", {})[
240+
"piecewise_enabled"
241+
] = False
232242

233243
args = LlmArgs(
234244
**config["args"],

0 commit comments

Comments
 (0)