Skip to content

Commit 4f25485

Browse files
authored
feat: configure cudagraph capture batch sizes (#4573)
* feat: configure cudagraph capture batch sizes * chore: simplify cudagraph capture sizes * fix: require cudagraph capture coverage * fix: normalize cudagraph capture sizes
1 parent d0ba19c commit 4f25485

8 files changed

Lines changed: 120 additions & 2 deletions

File tree

lmdeploy/cli/serve.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def add_parser_api_server():
128128
ArgumentHelper.enable_eplb(pt_group)
129129
ArgumentHelper.role(pt_group)
130130
ArgumentHelper.migration_backend(pt_group)
131+
ArgumentHelper.cudagraph_capture_batch_sizes(pt_group)
131132
# multi-node serving args
132133
node_rank_act = ArgumentHelper.node_rank(pt_group)
133134
num_nodes_act = ArgumentHelper.num_nodes(pt_group)
@@ -237,6 +238,7 @@ def api_server(args):
237238
quant_policy=args.quant_policy,
238239
eager_mode=args.eager_mode,
239240
max_prefill_token_num=args.max_prefill_token_num,
241+
cudagraph_capture_batch_sizes=args.cudagraph_capture_batch_sizes,
240242
enable_microbatch=args.enable_microbatch,
241243
enable_eplb=args.enable_eplb,
242244
enable_metrics=not args.disable_metrics,

lmdeploy/cli/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,16 @@ def max_prefill_token_num(parser):
616616
default=8192,
617617
help='the max number of tokens per iteration during prefill')
618618

619+
@staticmethod
620+
def cudagraph_capture_batch_sizes(parser):
621+
return parser.add_argument('--cudagraph-capture-batch-sizes',
622+
type=int,
623+
nargs='+',
624+
default=None,
625+
help='Batch sizes to capture CUDA graphs for in the PyTorch engine. '
626+
'If not specified, the engine infers them from max_batch_size. '
627+
'max_batch_size is always captured')
628+
619629
@staticmethod
620630
def vision_max_batch_size(parser):
621631
return parser.add_argument('--vision-max-batch-size', type=int, default=1, help='the vision model batch size')

lmdeploy/messages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ class PytorchEngineConfig:
362362
would be allocated according to current environment.
363363
adapters: The path configs to lora adapters.
364364
max_prefill_token_num: tokens per iteration.
365+
cudagraph_capture_batch_sizes: Batch sizes to capture CUDA graphs for.
366+
If not specified, the engine will infer them from max_batch_size.
367+
max_batch_size is always captured.
365368
thread_safe: thread safe engine instance.
366369
enable_prefix_caching: Enable token match and sharing caches.
367370
device_type: The inference device type, options ['cuda']
@@ -422,6 +425,7 @@ class PytorchEngineConfig:
422425
num_gpu_blocks: int = 0
423426
adapters: dict[str, str] = None
424427
max_prefill_token_num: int = 8192
428+
cudagraph_capture_batch_sizes: list[int] | None = None
425429
thread_safe: bool = False
426430
enable_prefix_caching: bool = False
427431
device_type: str = 'cuda'

lmdeploy/pytorch/backends/cuda/graph_runner.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77

88
from lmdeploy.pytorch.backends.deepep_state import get_deepep_state
99
from lmdeploy.pytorch.backends.selector import get_backend
10-
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
10+
from lmdeploy.pytorch.config import (
11+
BackendConfig,
12+
CacheConfig,
13+
ModelConfig,
14+
normalize_cudagraph_capture_batch_sizes,
15+
)
1116
from lmdeploy.pytorch.envs import fake_capture
1217
from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
1318
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
@@ -342,4 +347,8 @@ def update_inputs(self, inputs):
342347

343348
def get_capture_batch_sizes(self) -> list[int]:
    """Return the batch sizes to capture.

    When ``cache_config.cudagraph_capture_batch_sizes`` is set, it is
    normalized against ``cache_config.max_batches``, stored back on the
    cache config and returned. Otherwise the sizes come from
    ``_get_capture_batch_size_impl(cache_config.max_batches)``.
    """
    cache_config = self.cache_config
    configured_sizes = cache_config.cudagraph_capture_batch_sizes
    if configured_sizes is None:
        return _get_capture_batch_size_impl(cache_config.max_batches)

    # Re-normalize on every call so a value assigned after construction is still sanitized.
    normalized_sizes = normalize_cudagraph_capture_batch_sizes(configured_sizes, cache_config.max_batches)
    cache_config.cudagraph_capture_batch_sizes = normalized_sizes
    return normalized_sizes

lmdeploy/pytorch/backends/graph_runner.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
import torch
66

7-
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
7+
from lmdeploy.pytorch.config import (
8+
BackendConfig,
9+
CacheConfig,
10+
ModelConfig,
11+
normalize_cudagraph_capture_batch_sizes,
12+
)
813
from lmdeploy.pytorch.model_inputs import StepContext
914

1015

@@ -101,4 +106,8 @@ def update_inputs(self, inputs):
101106

102107
def get_capture_batch_sizes(self) -> list[int]:
    """Return the batch sizes to capture.

    When ``cache_config.cudagraph_capture_batch_sizes`` is set, it is
    normalized against ``cache_config.max_batches``, stored back on the
    cache config and returned. Otherwise the sizes come from
    ``_get_capture_batch_size_impl(cache_config.max_batches)``.
    """
    cache_config = self.cache_config
    configured_sizes = cache_config.cudagraph_capture_batch_sizes
    if configured_sizes is None:
        return _get_capture_batch_size_impl(cache_config.max_batches)

    # Re-normalize on every call so a value assigned after construction is still sanitized.
    normalized_sizes = normalize_cudagraph_capture_batch_sizes(configured_sizes, cache_config.max_batches)
    cache_config.cudagraph_capture_batch_sizes = normalized_sizes
    return normalized_sizes

lmdeploy/pytorch/config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@
1414
logger = get_logger('lmdeploy')
1515

1616

17+
def normalize_cudagraph_capture_batch_sizes(capture_sizes: list[int] | None, max_batches: int) -> list[int] | None:
    """Normalize configured cudagraph capture batch sizes.

    Args:
        capture_sizes: The configured batch sizes, or ``None`` when nothing
            was configured. The input list is never modified.
        max_batches: Largest batch size to keep; it is always part of the
            returned list.

    Returns:
        ``None`` when ``capture_sizes`` is ``None``. Otherwise a new list,
        de-duplicated and sorted ascending, holding every configured size
        ``<= max_batches`` and ending with ``max_batches``.

    Raises:
        AssertionError: If ``capture_sizes`` is empty, contains anything
            other than positive integers, or has no value ``<= max_batches``.
    """
    if capture_sizes is None:
        return None

    # Raise explicitly rather than via ``assert``: assert statements are
    # removed under ``python -O``, which would let invalid input through.
    # AssertionError is kept so existing callers observe the same exception.
    if len(capture_sizes) == 0:
        raise AssertionError('cudagraph_capture_batch_sizes should not be empty')

    # ``bool`` is a subclass of ``int``; reject it so ``True`` is never
    # treated as batch size 1.
    all_positive_ints = all(
        isinstance(size, int) and not isinstance(size, bool) and size > 0 for size in capture_sizes)
    if not all_positive_ints:
        raise AssertionError('cudagraph_capture_batch_sizes should be positive integers')

    normalized = sorted({size for size in capture_sizes if size <= max_batches})
    if len(normalized) == 0:
        raise AssertionError('cudagraph_capture_batch_sizes should contain at least one value '
                             f'<= max_batch_size ({max_batches})')

    # max_batches must always be captured.
    if normalized[-1] != max_batches:
        normalized.append(max_batches)
    return normalized
33+
34+
1735
def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'):
1836
"""Update the torch dtype from the model config.
1937
@@ -98,6 +116,7 @@ class CacheConfig:
98116
window_size: int = -1
99117
cache_max_entry_count: float = 0.8
100118
max_prefill_token_num: int = 8192
119+
cudagraph_capture_batch_sizes: list[int] | None = None
101120
enable_prefix_caching: bool = False
102121
quant_policy: QuantPolicy = QuantPolicy.NONE
103122
device_type: str = 'cuda'
@@ -118,6 +137,8 @@ def __post_init__(self):
118137
self.enable_prefix_caching = False
119138
if self.kernel_block_size == -1:
120139
self.kernel_block_size = self.block_size
140+
self.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes(
141+
self.cudagraph_capture_batch_sizes, self.max_batches)
121142

122143

123144
class TPMode(enum.Enum):
@@ -611,6 +632,7 @@ def from_config(
611632
num_gpu_blocks=target_cache_cfg.num_gpu_blocks,
612633
cache_max_entry_count=target_cache_cfg.cache_max_entry_count,
613634
max_prefill_token_num=target_cache_cfg.max_prefill_token_num,
635+
cudagraph_capture_batch_sizes=target_cache_cfg.cudagraph_capture_batch_sizes,
614636
device_type=target_cache_cfg.device_type,
615637
quant_policy=target_cache_cfg.quant_policy,
616638
migration_backend=target_cache_cfg.migration_backend)

lmdeploy/pytorch/engine/config_builder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
MiscConfig,
1111
SchedulerConfig,
1212
SpecDecodeConfig,
13+
normalize_cudagraph_capture_batch_sizes,
1314
)
1415
from lmdeploy.utils import get_logger, get_max_batch_size, get_model
1516

@@ -39,6 +40,11 @@ def update_engine_config(engine_config: PytorchEngineConfig):
3940
f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size '
4041
f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).')
4142

43+
capture_sizes = engine_config.cudagraph_capture_batch_sizes
44+
if capture_sizes is not None:
45+
engine_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes(
46+
capture_sizes, engine_config.max_batch_size)
47+
4248
if engine_config.dp != 1:
4349
if engine_config.tp == 1 and engine_config.ep == 1:
4450
logger.warning('Data parallelism is enabled but tensor parallelism and '
@@ -67,6 +73,7 @@ def build_cache_config(engine_config: PytorchEngineConfig):
6773
num_gpu_blocks=engine_config.num_gpu_blocks,
6874
cache_max_entry_count=engine_config.cache_max_entry_count,
6975
max_prefill_token_num=engine_config.max_prefill_token_num,
76+
cudagraph_capture_batch_sizes=engine_config.cudagraph_capture_batch_sizes,
7077
enable_prefix_caching=engine_config.enable_prefix_caching,
7178
quant_policy=engine_config.quant_policy,
7279
device_type=engine_config.device_type,
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
4+
from lmdeploy.messages import PytorchEngineConfig
5+
from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
6+
from lmdeploy.pytorch.config import CacheConfig
7+
from lmdeploy.pytorch.engine.config_builder import ConfigBuilder
8+
9+
10+
def _cache_config(max_batches=8, cudagraph_capture_batch_sizes=None):
    """Build a small ``CacheConfig`` with fixed block settings for these tests."""
    config_kwargs = dict(
        max_batches=max_batches,
        block_size=64,
        num_cpu_blocks=0,
        num_gpu_blocks=1,
        cudagraph_capture_batch_sizes=cudagraph_capture_batch_sizes,
    )
    return CacheConfig(**config_kwargs)
16+
17+
18+
def test_custom_capture_batch_sizes_include_max_batch_size():
    """Updating the engine config sorts, de-duplicates and caps the sizes at max_batch_size."""
    configured = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16])

    updated = ConfigBuilder.update_engine_config(configured)

    expected_sizes = [1, 4, 8]
    assert updated.cudagraph_capture_batch_sizes == expected_sizes
24+
25+
26+
def test_cache_config_normalizes_capture_batch_sizes():
    """Constructing a ``CacheConfig`` leaves normalized capture sizes on it."""
    raw_sizes = [4, 1, 4, 16]

    built = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=raw_sizes)

    assert built.cudagraph_capture_batch_sizes == [1, 4, 8]
30+
31+
32+
@pytest.mark.parametrize('sizes', [[], [0], [-1], [1.5], ['1'], [16]])
def test_invalid_capture_batch_sizes_raise(sizes):
    """Each parametrized ``sizes`` value must make ``CacheConfig`` construction raise."""
    expect_failure = pytest.raises(AssertionError)
    with expect_failure:
        _cache_config(max_batches=8, cudagraph_capture_batch_sizes=sizes)
36+
37+
38+
def test_capture_batch_size_miss_raises():
    """A batch of 5 maps to capture size 8; a batch of 9 raises."""
    configured = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[1, 4])
    updated = ConfigBuilder.update_engine_config(configured)

    # Bypass __init__: only ``cache_config`` is set on the runner for this check.
    graph_runner = object.__new__(CUDAGraphRunner)
    graph_runner.cache_config = ConfigBuilder.build_cache_config(updated)

    assert graph_runner._get_capture_tokens(5) == 8
    with pytest.raises(AssertionError):
        graph_runner._get_capture_tokens(9)
47+
48+
49+
def test_graph_runner_defensively_normalizes_capture_batch_sizes():
    """Sizes assigned after construction are normalized by ``get_capture_batch_sizes``."""
    config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[1, 8])
    # Overwrite the already-normalized value with raw, unnormalized input.
    config.cudagraph_capture_batch_sizes = [4, 1, 4, 16]

    # Bypass __init__: only ``cache_config`` is set on the runner for this check.
    graph_runner = object.__new__(CUDAGraphRunner)
    graph_runner.cache_config = config

    captured_sizes = graph_runner.get_capture_batch_sizes()
    assert captured_sizes == [1, 4, 8]

0 commit comments

Comments
 (0)