1616
1717import threading
1818
19+ import paddle
20+
1921from fastdeploy .model_executor .forward_meta import ForwardMeta
2022
2123event0 = threading .Event ()
@@ -40,31 +42,64 @@ def let_another_thread_run():
4042 GLOBAL_THREAD_INFO [thread_name ][0 ].clear ()
4143
4244
43- def split_batch_decoder_layers (forward_meta : ForwardMeta ):
44- split_num = 2
45- real_bs = forward_meta .seq_lens_this_time .shape [0 ]
def is_last_thread():
    """Report whether the calling thread is the one named ``"thread1"``.

    The check is purely name-based: it reads the name of
    ``threading.current_thread()`` and compares it with the literal
    ``"thread1"``.
    NOTE(review): presumably ``"thread1"`` is the second (last) of the two
    worker threads this module drives -- confirm against the thread setup code.

    Returns:
        bool: ``True`` when the current thread's name is ``"thread1"``,
        ``False`` for any other thread (including the main thread).
    """
    return threading.current_thread().name == "thread1"
4849
49- if real_bs < split_num or forward_meta .ids_remove_padding .shape [0 ] == 0 :
50- return res
5150
52- mc_bs = ( real_bs + split_num - 1 ) // split_num
def creat_empty_forward_meta(forward_meta: ForwardMeta):
    """Build a ``ForwardMeta`` that carries zero tokens but shares resources.

    The new object is constructed with an empty ``[0:0]`` slice of
    ``forward_meta.ids_remove_padding`` and reuses the same ``rotary_embs``,
    ``attn_backend`` and ``caches`` objects as the input (they are passed
    through, not copied).

    Args:
        forward_meta: The populated metadata to derive the empty one from.

    Returns:
        ForwardMeta: A fresh instance with empty token data.
    """
    empty_meta = ForwardMeta(
        ids_remove_padding=forward_meta.ids_remove_padding[0:0],
        rotary_embs=forward_meta.rotary_embs,
        attn_backend=forward_meta.attn_backend,
        caches=forward_meta.caches,
    )

    # ``hidden_states`` / ``decode_states`` are not guaranteed to exist on every
    # ForwardMeta: split_batch_decoder_layers guards the same pair with this
    # exact ``hasattr`` check, so mirror it here instead of raising
    # AttributeError when they are absent.
    if hasattr(forward_meta, "hidden_states"):
        empty_meta.hidden_states = forward_meta.hidden_states[0:0]
        empty_meta.decode_states = forward_meta.decode_states[0:0]

    return empty_meta
6264
63- start_token_id = forward_meta .cu_seqlens_q [start_bs ].item ()
64- end_token_id = forward_meta .cu_seqlens_q [end_bs ].item ()
6565
66- if start_token_id >= end_token_id :
67- continue
66+ def split_batch_decoder_layers (forward_meta : ForwardMeta , fd_config ):
67+ split_num = 2
68+ res = [creat_empty_forward_meta (forward_meta ), forward_meta ]
69+ res [0 ].tbo_microbatch_id = 0
70+ res [1 ].tbo_microbatch_id = 1
71+ total_token_num = forward_meta .ids_remove_padding .shape [0 ]
72+
73+ if total_token_num < 1024 :
74+ return res
75+
76+ chunk_token_num = (total_token_num + split_num - 1 ) // split_num
77+
78+ split_sections = []
79+ for i in range (0 , split_num ):
80+ start_token_id = i * chunk_token_num
81+ end_token_id = start_token_id + chunk_token_num
82+ end_token_id = min (total_token_num , end_token_id )
83+ split_sections .append (end_token_id )
84+
85+ # For multimodal image understanding, the multimodal (image) tokens must stay grouped together,
86+ # so split_sections[0] needs to be shifted accordingly.
87+
88+ special_tokens = [
89+ fd_config .model_config .image_patch_id ,
90+ ]
91+
92+ ids_remove_padding_cpu = forward_meta .ids_remove_padding .numpy ().tolist ()
93+ detect_pos = split_sections [0 ]
94+ while ids_remove_padding_cpu [detect_pos ] in special_tokens :
95+ detect_pos += 1
96+ if detect_pos >= len (ids_remove_padding_cpu ):
97+ return res
98+ split_sections [0 ] = detect_pos
99+
100+ for i in range (0 , split_num ):
101+ start_token_id = 0 if i == 0 else split_sections [i - 1 ]
102+ end_token_id = split_sections [i ]
68103
69104 res [i ] = ForwardMeta (
70105 ids_remove_padding = None ,
@@ -73,42 +108,62 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
73108 caches = forward_meta .caches ,
74109 )
75110
111+ # The token span we need to process lies within batches [start_bs, end_bs).
112+ start_bs = forward_meta .batch_id_per_token [start_token_id ]
113+ end_bs = forward_meta .batch_id_per_token [end_token_id - 1 ]
114+ end_bs += 1
115+
76116 if len (forward_meta .rotary_embs .shape ) == 6 :
77117 max_bs = forward_meta .rotary_embs .shape [0 ]
78118 assert max_bs == forward_meta .block_tables .shape [0 ]
79119 assert forward_meta .rotary_embs .shape [1 :3 ] == [2 , 1 ]
80120 assert forward_meta .rotary_embs .shape [4 ] == 1
81121 res [i ].rotary_embs = forward_meta .rotary_embs [start_bs :end_bs ]
82-
122+ res [ i ]. block_tables = forward_meta . block_tables [ start_bs : end_bs ]
83123 res [i ].ids_remove_padding = forward_meta .ids_remove_padding [start_token_id :end_token_id ]
84124 res [i ].batch_id_per_token = forward_meta .batch_id_per_token [start_token_id :end_token_id ] - start_bs
85125
86- res [i ].seq_lens_encoder = forward_meta .seq_lens_encoder [start_bs :end_bs ]
87- res [i ].seq_lens_decoder = forward_meta .seq_lens_decoder [start_bs :end_bs ]
88- res [i ].seq_lens_this_time = forward_meta .seq_lens_this_time [start_bs :end_bs ]
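126+ # The following three fields need careful handling; they are easy to get wrong.
127+ # Record how many tokens of batch start_bs have already been taken by the left chunk.
128+ # Record how many tokens of batch (end_bs - 1) have been taken by the right chunk.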
129+ start_bs_s_token_by_left_chunk = start_token_id - forward_meta .cu_seqlens_q [start_bs ].item ()
130+ end_bs_s_token_by_right_chunk = forward_meta .cu_seqlens_q [end_bs ].item () - end_token_id
89131
90- res [i ].block_tables = forward_meta .block_tables [start_bs :end_bs ]
132+ res [i ].seq_lens_this_time = forward_meta .seq_lens_this_time [start_bs :end_bs ] + 0
133+ res [i ].seq_lens_this_time [0 ] -= start_bs_s_token_by_left_chunk
134+ res [i ].seq_lens_this_time [- 1 ] -= end_bs_s_token_by_right_chunk
135+
136+ res [i ].seq_lens_encoder = forward_meta .seq_lens_encoder [start_bs :end_bs ] + 0
137+ if res [i ].seq_lens_encoder [0 ].item () > 0 :
138+ res [i ].seq_lens_encoder [0 ] -= start_bs_s_token_by_left_chunk
139+ if res [i ].seq_lens_encoder [- 1 ].item () > 0 :
140+ res [i ].seq_lens_encoder [- 1 ] -= end_bs_s_token_by_right_chunk
141+
142+ res [i ].seq_lens_decoder = forward_meta .seq_lens_decoder [start_bs :end_bs ] + 0
143+ res [i ].seq_lens_decoder [0 ] += start_bs_s_token_by_left_chunk
144+
145+ cu_seqlens_q = [0 ] + paddle .cumsum (res [i ].seq_lens_this_time ).numpy ().tolist ()
146+ res [i ].cu_seqlens_q = paddle .to_tensor (cu_seqlens_q ).cast ("int32" )
91147
92- res [i ].cu_seqlens_q = forward_meta .cu_seqlens_q [start_bs : end_bs + 1 ] - start_token_id
93- res [i ].cu_seqlens_k = forward_meta .cu_seqlens_k [start_bs : end_bs + 1 ] - start_token_id
148+ # res[i].cu_seqlens_k = res[i].cu_seqlens_q
94149
95150 for key in GLOBAL_ATTN_BUFFERS [i ]:
96151 setattr (res [i ], key , GLOBAL_ATTN_BUFFERS [i ][key ])
97152
98153 if forward_meta .attn_mask_offsets is not None :
99154 mask_num = forward_meta .attn_mask_offsets .shape [0 ]
100- token_num = forward_meta .ids_remove_padding .shape [0 ]
101- if mask_num == token_num * 2 :
155+ if mask_num == total_token_num * 2 :
102156 res [i ].attn_mask_offsets = forward_meta .attn_mask_offsets [start_token_id * 2 : end_token_id * 2 ]
103- elif mask_num == token_num :
157+ elif mask_num == total_token_num :
104158 res [i ].attn_mask_offsets = forward_meta .attn_mask_offsets [start_token_id :end_token_id ]
105159 else :
106160 assert False , "Invalid attn_mask_offsets shape"
107161
108162 # This is adapt 5.0
109163 if hasattr (forward_meta , "hidden_states" ):
110164 res [i ].hidden_states = forward_meta .hidden_states [start_token_id :end_token_id ]
165+ # The line below is not actually needed, because the text-only case does not use it.
111166 res [i ].decode_states = forward_meta .decode_states [start_bs :end_bs ]
112167
113- res [i ].attn_backend . init_attention_metadata ( res [ i ])
168+ res [i ].tbo_microbatch_id = i
114169 return res
0 commit comments