5656from fastdeploy .spec_decode import SpecMethod
5757from fastdeploy .utils import print_gpu_memory_use
5858from fastdeploy .worker .input_batch import InputBatch , reorder_split_prefill_and_decode
59+ from fastdeploy .worker .tbo import GLOBAL_ATTN_BUFFERS
5960
6061if current_platform .is_iluvatar ():
6162 from fastdeploy .model_executor .ops .iluvatar import (
@@ -1530,7 +1531,7 @@ def _initialize_attn_backend(self) -> None:
15301531 if envs .FD_DETERMINISTIC_MODE :
15311532 decoder_block_shape_q = envs .FD_DETERMINISTIC_SPLIT_KV_SIZE
15321533
1533- res_buffer = allocate_launch_related_buffer (
1534+ buffer_kwargs = dict (
15341535 max_batch_size = self .scheduler_config .max_num_seqs ,
15351536 max_model_len = self .model_config .max_model_len ,
15361537 encoder_block_shape_q = encoder_block_shape_q ,
@@ -1540,8 +1541,13 @@ def _initialize_attn_backend(self) -> None:
15401541 kv_num_heads = self .model_config .kv_num_heads ,
15411542 block_size = self .fd_config .cache_config .block_size ,
15421543 )
1544+ res_buffer = allocate_launch_related_buffer (** buffer_kwargs )
15431545 self .share_inputs .update (res_buffer )
15441546
1547+ if int (os .getenv ("USE_TBO" , "0" )) == 1 :
1548+ for j in range (2 ):
1549+ GLOBAL_ATTN_BUFFERS [j ] = allocate_launch_related_buffer (** buffer_kwargs )
1550+
15451551 # Get the attention backend
15461552 attn_cls = get_attention_backend ()
15471553 attn_backend = attn_cls (
0 commit comments