Skip to content

Commit d66f329

Browse files
fixing smoke test failures; mostly flashinfer shapes
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent d18e1a2 commit d66f329

4 files changed

Lines changed: 47 additions & 4 deletions

File tree

tests/unittest/auto_deploy/singlegpu/smoke/test_ad_build_small_single.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,10 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
214214
"mistralai/Mistral-Small-3.1-24B-Instruct-2503",
215215
{
216216
"transforms": {
217-
"insert_cached_attention": {"backend": "flashinfer"},
217+
"insert_cached_attention": {
218+
"backend": "flashinfer",
219+
"requires_shape_prop": True,
220+
},
218221
"compile_model": {
219222
"backend": "torch-cudagraph",
220223
"cuda_graph_batch_sizes": [1, 2],
@@ -295,6 +298,22 @@ def test_build_ad(model_hub_id: str, llm_extra_args: dict):
295298
experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
296299
experiment_config["args"]["runtime"] = "demollm" # Default runtime set to demollm
297300
experiment_config["args"]["world_size"] = 0 # Default world_size set to 0
301+
if (
302+
model_hub_id == "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
303+
and llm_extra_args.get("transforms", {})
304+
.get("compile_model", {})
305+
.get("backend")
306+
== "torch-cudagraph"
307+
):
308+
experiment_config["args"]["max_batch_size"] = 1
309+
experiment_config["args"]["max_input_len"] = 64
310+
experiment_config["args"]["max_seq_len"] = 128
311+
experiment_config["args"]["max_num_tokens"] = 128
312+
experiment_config["args"]["cuda_graph_config"] = {
313+
"batch_sizes": [1],
314+
"max_batch_size": 1,
315+
}
316+
experiment_config["prompt"]["batch_size"] = 1
298317

299318
print(f"Experiment Config: {experiment_config}")
300319
experiment_config = ExperimentConfig(**experiment_config)

tests/unittest/auto_deploy/singlegpu/smoke/test_ad_guided_decoding_regex.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -37,6 +37,14 @@ def test_ad_guided_decoding_regex_e2e():
3737
# NOTE: trtllm attention backend fails on B200 (likely illegal memory access); use flashinfer.
3838
experiment_config["args"]["attn_backend"] = "flashinfer"
3939
experiment_config["args"]["guided_decoding_backend"] = guided_decoding_backend
40+
experiment_config["args"]["max_batch_size"] = 1
41+
experiment_config["args"]["max_input_len"] = 64
42+
experiment_config["args"]["max_seq_len"] = 128
43+
experiment_config["args"]["max_num_tokens"] = 128
44+
experiment_config["args"]["cuda_graph_config"] = {
45+
"batch_sizes": [1],
46+
"max_batch_size": 1,
47+
}
4048

4149
experiment_config["prompt"]["batch_size"] = 1
4250
experiment_config["prompt"]["queries"] = test_case["prompt"]
@@ -46,7 +54,7 @@ def test_ad_guided_decoding_regex_e2e():
4654
# Need to introduce the guided decoding params after ExperimentConfig construction
4755
# because otherwise they get unpacked as a dict.
4856
cfg.prompt.sp_kwargs = {
49-
"max_tokens": 10,
57+
"max_tokens": 16,
5058
"top_k": None,
5159
"temperature": 0.1,
5260
"guided_decoding": GuidedDecodingParams(regex=test_case["regex"]),

tests/unittest/auto_deploy/singlegpu/smoke/test_ad_trtllm_sampler.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,6 +31,14 @@ def test_ad_trtllm_sampler_smoke():
3131
# NOTE: trtllm attention backend fails on B200 (likely illegal memory access); use flashinfer.
3232
experiment_config["args"]["attn_backend"] = "flashinfer"
3333
experiment_config["args"]["sampler_type"] = SamplerType.TRTLLMSampler
34+
experiment_config["args"]["max_batch_size"] = 1
35+
experiment_config["args"]["max_input_len"] = 64
36+
experiment_config["args"]["max_seq_len"] = 128
37+
experiment_config["args"]["max_num_tokens"] = 128
38+
experiment_config["args"]["cuda_graph_config"] = {
39+
"batch_sizes": [1],
40+
"max_batch_size": 1,
41+
}
3442

3543
# Setup simple prompt
3644
experiment_config["prompt"]["batch_size"] = 1

tests/unittest/auto_deploy/singlegpu/smoke/test_ad_trtllm_serve.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,14 @@ def test_trtllm_serve_openai_chat_completion(tmp_path):
4141

4242
# NOTE: trtllm attention backend fails on B200 (likely illegal memory access); use flashinfer.
4343
extra_args["attn_backend"] = "flashinfer"
44+
extra_args["max_batch_size"] = 1
45+
extra_args["max_input_len"] = 64
46+
extra_args["max_seq_len"] = 128
47+
extra_args["max_num_tokens"] = 128
48+
extra_args["cuda_graph_config"] = {
49+
"batch_sizes": [1],
50+
"max_batch_size": 1,
51+
}
4452
extra_options_path = tmp_path / "extra_llm_api_options.yaml"
4553
with open(extra_options_path, "w") as f:
4654
yaml.safe_dump(extra_args, f)

0 commit comments

Comments
 (0)