Skip to content

Commit cbdb246

Browse files
cp 1131 tbo to develop (PaddlePaddle#6281)
1 parent 8277b95 commit cbdb246

3 files changed

Lines changed: 87 additions & 31 deletions

File tree

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,12 +560,13 @@ def create_deep_ep_buffer(self):
560560

561561

562562
class EPPrefillRunner(EPRunner):
563+
564+
allocate_on_comm_stream = False
565+
563566
"""
564567
EPPrefillRunner
565568
"""
566569

567-
allocate_on_comm_stream = False
568-
569570
def __init__(
570571
self,
571572
top_k: int,
@@ -664,6 +665,7 @@ def combine(
664665
"async_finish": self.ep_engine.async_finish,
665666
"topk_weights": recv_topk_weights,
666667
"previous_event": event,
668+
"allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
667669
}
668670
fused_moe_out, _, event = buffer.combine(**combine_args)
669671
return fused_moe_out, event

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ def apply_ep_prefill(
295295
token_all_num,
296296
)
297297
assert permute_input.shape[0] == token_all_num
298-
del recv_x
299298

300299
permute_scale = permute_scale.transpose([1, 0]).contiguous().transpose([1, 0])
301300

fastdeploy/worker/tbo.py

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import threading
1818

19+
import paddle
20+
1921
from fastdeploy.model_executor.forward_meta import ForwardMeta
2022

2123
event0 = threading.Event()
@@ -40,31 +42,64 @@ def let_another_thread_run():
4042
GLOBAL_THREAD_INFO[thread_name][0].clear()
4143

4244

43-
def split_batch_decoder_layers(forward_meta: ForwardMeta):
44-
split_num = 2
45-
real_bs = forward_meta.seq_lens_this_time.shape[0]
45+
def is_last_thread():
46+
thread_name = threading.current_thread().name
4647

47-
res = [forward_meta] * split_num
48+
return thread_name == "thread1"
4849

49-
if real_bs < split_num or forward_meta.ids_remove_padding.shape[0] == 0:
50-
return res
5150

52-
mc_bs = (real_bs + split_num - 1) // split_num
51+
def creat_empty_forward_meta(forward_meta: ForwardMeta):
5352

54-
for i in range(0, split_num):
55-
start_bs = i * mc_bs
53+
res = ForwardMeta(
54+
ids_remove_padding=forward_meta.ids_remove_padding[0:0],
55+
rotary_embs=forward_meta.rotary_embs,
56+
attn_backend=forward_meta.attn_backend,
57+
caches=forward_meta.caches,
58+
)
5659

57-
end_bs = start_bs + mc_bs
58-
end_bs = min(end_bs, real_bs)
60+
res.hidden_states = forward_meta.hidden_states[0:0]
61+
res.decode_states = forward_meta.decode_states[0:0]
5962

60-
if start_bs >= end_bs:
61-
continue
63+
return res
6264

63-
start_token_id = forward_meta.cu_seqlens_q[start_bs].item()
64-
end_token_id = forward_meta.cu_seqlens_q[end_bs].item()
6565

66-
if start_token_id >= end_token_id:
67-
continue
66+
def split_batch_decoder_layers(forward_meta: ForwardMeta, fd_config):
67+
split_num = 2
68+
res = [creat_empty_forward_meta(forward_meta), forward_meta]
69+
res[0].tbo_microbatch_id = 0
70+
res[1].tbo_microbatch_id = 1
71+
total_token_num = forward_meta.ids_remove_padding.shape[0]
72+
73+
if total_token_num < 1024:
74+
return res
75+
76+
chunk_token_num = (total_token_num + split_num - 1) // split_num
77+
78+
split_sections = []
79+
for i in range(0, split_num):
80+
start_token_id = i * chunk_token_num
81+
end_token_id = start_token_id + chunk_token_num
82+
end_token_id = min(total_token_num, end_token_id)
83+
split_sections.append(end_token_id)
84+
85+
# 由于多模的图片理解,需要将多模拟的token聚集在一起!
86+
# 所以需要将split_sections[0]适当的偏移一下!
87+
88+
special_tokens = [
89+
fd_config.model_config.image_patch_id,
90+
]
91+
92+
ids_remove_padding_cpu = forward_meta.ids_remove_padding.numpy().tolist()
93+
detect_pos = split_sections[0]
94+
while ids_remove_padding_cpu[detect_pos] in special_tokens:
95+
detect_pos += 1
96+
if detect_pos >= len(ids_remove_padding_cpu):
97+
return res
98+
split_sections[0] = detect_pos
99+
100+
for i in range(0, split_num):
101+
start_token_id = 0 if i == 0 else split_sections[i - 1]
102+
end_token_id = split_sections[i]
68103

69104
res[i] = ForwardMeta(
70105
ids_remove_padding=None,
@@ -73,42 +108,62 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
73108
caches=forward_meta.caches,
74109
)
75110

111+
# 我们需要处理的这一段token位于[start_bs, end_bs)里面!
112+
start_bs = forward_meta.batch_id_per_token[start_token_id]
113+
end_bs = forward_meta.batch_id_per_token[end_token_id - 1]
114+
end_bs += 1
115+
76116
if len(forward_meta.rotary_embs.shape) == 6:
77117
max_bs = forward_meta.rotary_embs.shape[0]
78118
assert max_bs == forward_meta.block_tables.shape[0]
79119
assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
80120
assert forward_meta.rotary_embs.shape[4] == 1
81121
res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
82-
122+
res[i].block_tables = forward_meta.block_tables[start_bs:end_bs]
83123
res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id]
84124
res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs
85125

86-
res[i].seq_lens_encoder = forward_meta.seq_lens_encoder[start_bs:end_bs]
87-
res[i].seq_lens_decoder = forward_meta.seq_lens_decoder[start_bs:end_bs]
88-
res[i].seq_lens_this_time = forward_meta.seq_lens_this_time[start_bs:end_bs]
126+
# 下面这三个要好好弄,小心出错!
127+
# 我需要记录下 start_bs 他被left chunk 瓜分了多少了!
128+
# 我需要记录下 (end_bs-1) 他被 right chunk 瓜分了多少了!
129+
start_bs_s_token_by_left_chunk = start_token_id - forward_meta.cu_seqlens_q[start_bs].item()
130+
end_bs_s_token_by_right_chunk = forward_meta.cu_seqlens_q[end_bs].item() - end_token_id
89131

90-
res[i].block_tables = forward_meta.block_tables[start_bs:end_bs]
132+
res[i].seq_lens_this_time = forward_meta.seq_lens_this_time[start_bs:end_bs] + 0
133+
res[i].seq_lens_this_time[0] -= start_bs_s_token_by_left_chunk
134+
res[i].seq_lens_this_time[-1] -= end_bs_s_token_by_right_chunk
135+
136+
res[i].seq_lens_encoder = forward_meta.seq_lens_encoder[start_bs:end_bs] + 0
137+
if res[i].seq_lens_encoder[0].item() > 0:
138+
res[i].seq_lens_encoder[0] -= start_bs_s_token_by_left_chunk
139+
if res[i].seq_lens_encoder[-1].item() > 0:
140+
res[i].seq_lens_encoder[-1] -= end_bs_s_token_by_right_chunk
141+
142+
res[i].seq_lens_decoder = forward_meta.seq_lens_decoder[start_bs:end_bs] + 0
143+
res[i].seq_lens_decoder[0] += start_bs_s_token_by_left_chunk
144+
145+
cu_seqlens_q = [0] + paddle.cumsum(res[i].seq_lens_this_time).numpy().tolist()
146+
res[i].cu_seqlens_q = paddle.to_tensor(cu_seqlens_q).cast("int32")
91147

92-
res[i].cu_seqlens_q = forward_meta.cu_seqlens_q[start_bs : end_bs + 1] - start_token_id
93-
res[i].cu_seqlens_k = forward_meta.cu_seqlens_k[start_bs : end_bs + 1] - start_token_id
148+
# res[i].cu_seqlens_k = res[i].cu_seqlens_q
94149

95150
for key in GLOBAL_ATTN_BUFFERS[i]:
96151
setattr(res[i], key, GLOBAL_ATTN_BUFFERS[i][key])
97152

98153
if forward_meta.attn_mask_offsets is not None:
99154
mask_num = forward_meta.attn_mask_offsets.shape[0]
100-
token_num = forward_meta.ids_remove_padding.shape[0]
101-
if mask_num == token_num * 2:
155+
if mask_num == total_token_num * 2:
102156
res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id * 2 : end_token_id * 2]
103-
elif mask_num == token_num:
157+
elif mask_num == total_token_num:
104158
res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id:end_token_id]
105159
else:
106160
assert False, "Invalid attn_mask_offsets shape"
107161

108162
# This is adapt 5.0
109163
if hasattr(forward_meta, "hidden_states"):
110164
res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id]
165+
# 下面这个其实不需要,因为纯文不需要这个!
111166
res[i].decode_states = forward_meta.decode_states[start_bs:end_bs]
112167

113-
res[i].attn_backend.init_attention_metadata(res[i])
168+
res[i].tbo_microbatch_id = i
114169
return res

0 commit comments

Comments
 (0)