Commit 9f33332

Tsundoku958 and maoruihan authored
lmdeploy support kernel block size (#4421)
* add kernel block config
* support kernel block size
* add comment
* fix
* fix format
* move cal kernel offs to _tensorlize_block_offsets
* fix
* add map kernel offs func

Co-authored-by: maoruihan <maoruihan@stonewise.cn>
1 parent cc51fc8 commit 9f33332

8 files changed

Lines changed: 66 additions & 7 deletions


lmdeploy/cli/serve.py

Lines changed: 2 additions & 0 deletions
@@ -106,6 +106,7 @@ def add_parser_api_server():
         ArgumentHelper.dllm_confidence_threshold(pt_group)
         ArgumentHelper.enable_return_routed_experts(pt_group)
         ArgumentHelper.distributed_executor_backend(pt_group)
+        ArgumentHelper.kernel_block_size(pt_group)

         # common engine args
         dtype_act = ArgumentHelper.dtype(pt_group)
@@ -226,6 +227,7 @@ def api_server(args):
             max_batch_size=max_batch_size,
             cache_max_entry_count=args.cache_max_entry_count,
             block_size=args.cache_block_seq_len,
+            kernel_block_size=args.kernel_block_size,
             session_len=args.session_len,
             adapters=adapters,
             enable_prefix_caching=args.enable_prefix_caching,

lmdeploy/cli/utils.py

Lines changed: 13 additions & 0 deletions
@@ -546,6 +546,19 @@ def cache_block_seq_len(parser):
                                 'if Lora Adapter is specified, this parameter will '
                                 'be ignored')

+    @staticmethod
+    def kernel_block_size(parser):
+        """Add argument kernel_block_size to parser."""
+
+        return parser.add_argument('--kernel-block-size',
+                                   type=int,
+                                   default=-1,
+                                   help='The length of the token sequence in a k/v block for kernels. '
+                                   'Only supported by Pytorch Engine. '
+                                   'When set to a different value than --cache-block-seq-len, '
+                                   'memory allocators and prefix cache use --cache-block-seq-len '
+                                   'as the block size, while kernels use --kernel-block-size.')
+
     @staticmethod
     def enable_prefix_caching(parser):
         """Add argument enable_prefix_caching to parser."""

lmdeploy/messages.py

Lines changed: 11 additions & 2 deletions
@@ -392,6 +392,7 @@ class PytorchEngineConfig:
     cache_max_entry_count: float = 0.8
     prefill_interval: int = 16
     block_size: int = 64
+    kernel_block_size: int = -1
     num_cpu_blocks: int = 0
     num_gpu_blocks: int = 0
     adapters: dict[str, str] = None
@@ -430,6 +431,8 @@ class PytorchEngineConfig:

     def __post_init__(self):
         """Check input validation."""
+        if self.kernel_block_size == -1:
+            self.kernel_block_size = self.block_size
         assert self.dtype in ['auto', 'float16', 'bfloat16']
         assert self.tp >= 1, 'invalid tp'
         assert self.dp >= 1, 'invalid dp'
@@ -442,8 +445,14 @@ def __post_init__(self):
         assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
         assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
         assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
-        assert self.block_size >= 16 and (self.block_size & (self.block_size - 1)) == 0, \
-            f'block_size must be >= 16 and a power of 2, but got {self.block_size}'
+        assert self.kernel_block_size >= 16 and \
+            (self.kernel_block_size & (self.kernel_block_size - 1)) == 0, \
+            f'kernel_block_size must be >= 16 and a power of 2, but got {self.kernel_block_size}'
+        assert self.block_size >= self.kernel_block_size and \
+            self.block_size % self.kernel_block_size == 0, \
+            (f'block_size must be >= kernel_block_size and an integer multiple '
+             f'of kernel_block_size, but got block_size {self.block_size} '
+             f'and kernel_block_size {self.kernel_block_size}')
         if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']:
             assert False, \
                 'kv cache quantization only works for CUDA and ASCEND.'
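In effect, kernel_block_size inherits block_size when left at -1, must itself be a power of two no smaller than 16, and must divide block_size evenly. A small sketch of the resulting behavior (assuming PytorchEngineConfig is imported from lmdeploy's top level):

from lmdeploy import PytorchEngineConfig

# Left at -1, kernel_block_size falls back to block_size.
cfg = PytorchEngineConfig(block_size=64)
assert cfg.kernel_block_size == 64

# Valid: 16 is a power of two >= 16 and divides 64 evenly.
cfg = PytorchEngineConfig(block_size=64, kernel_block_size=16)

# Both of these now raise AssertionError in __post_init__:
#   PytorchEngineConfig(block_size=64, kernel_block_size=48)  # not a power of 2
#   PytorchEngineConfig(block_size=32, kernel_block_size=64)  # exceeds block_size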

lmdeploy/pytorch/config.py

Lines changed: 3 additions & 0 deletions
@@ -94,6 +94,7 @@ class CacheConfig:
     block_size: int
     num_cpu_blocks: int
     num_gpu_blocks: int
+    kernel_block_size: int = -1
     window_size: int = -1
     cache_max_entry_count: float = 0.8
     max_prefill_token_num: int = 4096
@@ -115,6 +116,8 @@ def __post_init__(self):
         if self.window_size > 1 and self.enable_prefix_caching:
             logger.warning('Prefix caching is not available for window attention.')
             self.enable_prefix_caching = False
+        if self.kernel_block_size == -1:
+            self.kernel_block_size = self.block_size


 class TPMode(enum.Enum):

lmdeploy/pytorch/engine/cache_engine.py

Lines changed: 6 additions & 4 deletions
@@ -84,7 +84,7 @@ def __init__(
         self.cache_config = cache_config
         self.model_config = model_config

-        self.block_size = cache_config.block_size
+        self.block_size = cache_config.kernel_block_size
         self.num_layers = model_config.num_layers
         self.kv_cache_dtype = _get_kv_cache_dtype(self.model_config)

@@ -198,7 +198,7 @@ def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig,
         head_size = model_config.head_dim
         shape = cls._get_key_block_shape_impl(
             model_config,
-            block_size=cache_config.block_size,
+            block_size=cache_config.kernel_block_size,
             head_size=head_size,
             world_size=world_size,
             quant_policy=cache_config.quant_policy,
@@ -217,7 +217,7 @@ def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig,
         head_size = model_config.head_dim
         shape = cls._get_value_block_shape_impl(
             model_config,
-            block_size=cache_config.block_size,
+            block_size=cache_config.kernel_block_size,
             head_size=head_size,
             world_size=world_size,
             quant_policy=cache_config.quant_policy,
@@ -248,7 +248,7 @@ def get_custom_cache_descs(cls, model_config: ModelConfig, cache_config: CacheCo
         if len(model_config.cache_shapes) == 0:
             return []

-        block_size = cache_config.block_size
+        block_size = cache_config.kernel_block_size

         descs = []
         for shape, dtype in model_config.cache_shapes:
@@ -263,6 +263,8 @@ def allocate_caches(cls, num_blocks: int, model_config: ModelConfig, cache_confi
         """Allocate caches."""

         num_layers = model_config.num_layers
+        kernel_blocks_per_kv = cache_config.block_size // cache_config.kernel_block_size
+        num_blocks *= kernel_blocks_per_kv

         # get all descs
         k_cache_desc = cls.get_k_cache_desc(model_config, cache_config, world_size)
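The last hunk keeps total KV-cache capacity unchanged: every allocator block of block_size tokens is simply backed by several smaller kernel blocks. A worked sketch with hypothetical sizes, not from the commit:

block_size = 64          # allocator / prefix-cache block, in tokens
kernel_block_size = 16   # block size the attention kernels see
num_blocks = 100         # blocks granted by the memory allocator

kernel_blocks_per_kv = block_size // kernel_block_size  # 4
num_kernel_blocks = num_blocks * kernel_blocks_per_kv   # 400

# Capacity is identical either way; only the granularity differs.
assert num_blocks * block_size == num_kernel_blocks * kernel_block_size  # 6400 tokens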

lmdeploy/pytorch/engine/config_builder.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def build_cache_config(engine_config: PytorchEngineConfig):
     cache_config = CacheConfig(
         max_batches=engine_config.max_batch_size,
         block_size=engine_config.block_size,
+        kernel_block_size=engine_config.kernel_block_size,
         num_cpu_blocks=engine_config.num_cpu_blocks,
         num_gpu_blocks=engine_config.num_gpu_blocks,
         cache_max_entry_count=engine_config.cache_max_entry_count,

lmdeploy/pytorch/engine/inputs_maker.py

Lines changed: 29 additions & 0 deletions
@@ -214,6 +214,9 @@ def __init__(
         self.adapter_manager = adapter_manager
         self.config = config
         self.spec_decoding = config.spec_decoding
+        self.cache_config = scheduler.cache_config
+        self.kernel_blocks_per_kv = self.cache_config.block_size // self.cache_config.kernel_block_size
+        self.kernel_block_arange = torch.arange(self.kernel_blocks_per_kv, dtype=self.torch_int_dtype)

         # strategies
         self.engine_strategy = engine_strategy
@@ -322,6 +325,29 @@ def _set_adapter_ids(self, model_inputs: ModelInputs, messages: 'SeqList'):
         local_adapter_ids = model_inputs.seq_length.new_tensor(local_adapter_ids)
         model_inputs.local_adapter_ids = local_adapter_ids

+    def _map_to_kernel_block_offsets(self, block_offsets: torch.Tensor):
+        """Convert block-manager block offsets to kernel block offsets.
+
+        Example:
+
+            # block_manager block size: 32 tokens
+            # kernel block size: 16 tokens
+            # kernel_blocks_per_kv = 2
+            >>> block_manager block offsets = [0, 1, 3]
+            >>> resulting kernel block offsets = [0, 1, 2, 3, 6, 7]
+
+            # Each block_manager block id maps to 2 kernel block ids:
+            # block_manager block id 0 -> kernel block ids [0, 1]
+            # block_manager block id 1 -> kernel block ids [2, 3]
+            # block_manager block id 3 -> kernel block ids [6, 7]
+        """
+        if self.kernel_blocks_per_kv == 1:
+            return block_offsets
+        batch_size = block_offsets.shape[0]
+        block_offsets = (block_offsets[:, :, None] * self.kernel_blocks_per_kv +
+                         self.kernel_block_arange[None, None, :]).reshape(batch_size, -1)
+        return block_offsets
+
     @torch.inference_mode()
     @record_function('create_model_inputs')
     def create_model_inputs(self, messages: 'SeqList', is_prefill: bool):
@@ -355,6 +381,7 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool):
         # block offsets
         block_offsets = self.scheduler.get_block_tables(messages)
         block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype)
+        block_offsets = self._map_to_kernel_block_offsets(block_offsets)

         # num_ignored_history
         num_ignored_history = torch.tensor([msg.num_ignored_history for msg in messages])
@@ -410,6 +437,7 @@ def create_model_inputs_long_context(self,
         # block offsets
         block_offsets = self.scheduler.get_block_tables([seq])
         block_offsets = torch.as_tensor(block_offsets[0], dtype=self.torch_int_dtype)[None]
+        block_offsets = self._map_to_kernel_block_offsets(block_offsets)

         # num_ignored_history
         num_ignored_history = torch.tensor([seq.num_ignored_history])
@@ -482,6 +510,7 @@ def create_model_inputs_delta(self):
         # block offsets
         block_offsets = self.scheduler.get_block_tables(valid_seqs)
         block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype)
+        block_offsets = self._map_to_kernel_block_offsets(block_offsets)

         # sliding window
         if self.scheduler.cache_config.window_size > 0:
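The broadcast in _map_to_kernel_block_offsets can be checked outside the engine. A standalone sketch with hypothetical values, using plain torch:

import torch

kernel_blocks_per_kv = 2
kernel_block_arange = torch.arange(kernel_blocks_per_kv, dtype=torch.int64)

# One sequence whose block-manager block table is [0, 1, 3].
block_offsets = torch.tensor([[0, 1, 3]], dtype=torch.int64)

# Each manager block id i expands into kernel block ids
# [i * kernel_blocks_per_kv, ..., i * kernel_blocks_per_kv + kernel_blocks_per_kv - 1].
batch_size = block_offsets.shape[0]
kernel_offsets = (block_offsets[:, :, None] * kernel_blocks_per_kv +
                  kernel_block_arange[None, None, :]).reshape(batch_size, -1)

print(kernel_offsets)  # tensor([[0, 1, 2, 3, 6, 7]])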

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -353,7 +353,7 @@ def has_unfinished(self):
         return self.has_ready() or self.has_waiting() or self.has_migration_done()

     def get_block_tables(self, seqs: SeqList):
-        """Get block table of the sequences."""
+        """Get block tables for the sequences."""
         return [self.block_manager.get_block_table(seq) for seq in seqs]

     def evict_seqs(self, running: SeqList):
