Skip to content

Commit 771d42c

Browse files
[TBO] Apply tbo to gpu_model_runner (#7165)
* apply tbo in gpu_model_runner
* fix
1 parent 4cd574c commit 771d42c

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

fastdeploy/worker/gpu_model_runner.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
from fastdeploy.spec_decode import SpecMethod
5757
from fastdeploy.utils import print_gpu_memory_use
5858
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
59+
from fastdeploy.worker.tbo import GLOBAL_ATTN_BUFFERS
5960

6061
if current_platform.is_iluvatar():
6162
from fastdeploy.model_executor.ops.iluvatar import (
@@ -1530,7 +1531,7 @@ def _initialize_attn_backend(self) -> None:
15301531
if envs.FD_DETERMINISTIC_MODE:
15311532
decoder_block_shape_q = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE
15321533

1533-
res_buffer = allocate_launch_related_buffer(
1534+
buffer_kwargs = dict(
15341535
max_batch_size=self.scheduler_config.max_num_seqs,
15351536
max_model_len=self.model_config.max_model_len,
15361537
encoder_block_shape_q=encoder_block_shape_q,
@@ -1540,8 +1541,13 @@ def _initialize_attn_backend(self) -> None:
15401541
kv_num_heads=self.model_config.kv_num_heads,
15411542
block_size=self.fd_config.cache_config.block_size,
15421543
)
1544+
res_buffer = allocate_launch_related_buffer(**buffer_kwargs)
15431545
self.share_inputs.update(res_buffer)
15441546

1547+
if int(os.getenv("USE_TBO", "0")) == 1:
1548+
for j in range(2):
1549+
GLOBAL_ATTN_BUFFERS[j] = allocate_launch_related_buffer(**buffer_kwargs)
1550+
15451551
# Get the attention backend
15461552
attn_cls = get_attention_backend()
15471553
attn_backend = attn_cls(

fastdeploy/worker/tbo.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,8 +114,6 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta, fd_config):
114114
end_bs += 1
115115

116116
if len(forward_meta.rotary_embs.shape) == 6:
117-
max_bs = forward_meta.rotary_embs.shape[0]
118-
assert max_bs == forward_meta.block_tables.shape[0]
119117
assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
120118
assert forward_meta.rotary_embs.shape[4] == 1
121119
res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]

0 commit comments

Comments
 (0)