Skip to content

Commit c83dc58

Browse files
[Feature] support Two batch overlap, mainly used in Prefill (PaddlePaddle#5078)
1 parent 1aefbef commit c83dc58

4 files changed

Lines changed: 143 additions & 3 deletions

File tree

fastdeploy/model_executor/layers/embeddings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def __init__(
135135
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
136136
self.params_dtype: str = params_dtype
137137

138+
self.embedding_dim = embedding_dim
139+
138140
self.general = general # used for general Embedding
139141
self.num_embeddings = num_embeddings
140142
self.padding_size = padding_size
@@ -297,6 +299,8 @@ def forward(self, ids_remove_padding=None) -> paddle.Tensor:
297299
Returns:
298300
Tensor: Embedded tensor representation of the input IDs.
299301
"""
302+
if ids_remove_padding.shape[0] == 0:
303+
return paddle.empty([0, self.embedding_dim], dtype=self.embeddings.weight.dtype)
300304
if self.column_cut:
301305
input_embedings = self.embeddings(ids_remove_padding)
302306
inputs_embeds_temp = []

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,8 @@ class EPPrefillRunner(EPRunner):
505505
EPPrefillRunner
506506
"""
507507

508+
allocate_on_comm_stream = False
509+
508510
def __init__(
509511
self,
510512
top_k: int,
@@ -533,6 +535,12 @@ def __init__(
533535
use_internode_ll_two_stage=use_internode_ll_two_stage,
534536
)
535537

538+
@staticmethod
def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
    """
    Set the class-wide ``EPPrefillRunner.allocate_on_comm_stream`` flag.

    The value is stored on the class (not on an instance), so it applies to
    every ``EPPrefillRunner``. ``dispatch`` forwards it as the
    ``allocate_on_comm_stream`` argument of ``get_dispatch_layout`` and of the
    buffer's ``dispatch`` call.

    Args:
        allocate_on_comm_stream: New value of the flag. Defaults to False.
    """
    # Declared as a staticmethod because the function takes no ``self``: without
    # the decorator, calling it through an instance would bind the instance to
    # ``allocate_on_comm_stream`` and store a truthy object as the flag.
    logger.info(
        f"set allocate_on_comm_stream to {allocate_on_comm_stream}, this will force Prefill dispatch's output tensor is allocated on communication stream"
    )
    EPPrefillRunner.allocate_on_comm_stream = allocate_on_comm_stream
543+
536544
def dispatch(
537545
self,
538546
x: paddle.Tensor,
@@ -552,7 +560,13 @@ def dispatch(
552560
num_tokens_per_expert,
553561
is_token_in_rank,
554562
event,
555-
) = buffer.get_dispatch_layout(topk_idx, self.num_experts, async_finish=self.ep_engine.async_finish)
563+
) = buffer.get_dispatch_layout(
564+
topk_idx,
565+
self.num_experts,
566+
previous_event=kwargs.get("previous_event", None),
567+
allocate_on_comm_stream=EPPrefillRunner.allocate_on_comm_stream,
568+
async_finish=self.ep_engine.async_finish,
569+
)
556570

557571
x_scale_tensor = kwargs.get("x_scale_tensor", None)
558572
dispatch_args = {
@@ -566,6 +580,7 @@ def dispatch(
566580
"topk_idx": topk_idx,
567581
"topk_weights": topk_weights,
568582
"expert_alignment": expert_alignment,
583+
"allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
569584
"previous_event": event,
570585
}
571586
return buffer.dispatch(**dispatch_args)
@@ -575,6 +590,7 @@ def combine(
575590
tmp_ffn_out: paddle.Tensor,
576591
handle: tuple,
577592
recv_topk_weights: paddle.Tensor,
593+
event=None,
578594
):
579595
buffer = self.ep_engine.deepep_engine
580596
if buffer is None:
@@ -586,6 +602,7 @@ def combine(
586602
"config": self.ep_engine.ep_config,
587603
"async_finish": self.ep_engine.async_finish,
588604
"topk_weights": recv_topk_weights,
605+
"previous_event": event,
589606
}
590607
fused_moe_out, _, event = buffer.combine(**combine_args)
591608
return fused_moe_out, event

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
import paddle
1818
from paddle import nn
19+
from paddle.distributed.communication import deep_ep
1920
from paddleformers.utils.log import logger
2021

2122
import fastdeploy
2223
from fastdeploy.model_executor.layers.utils import get_tensor
2324
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
25+
from fastdeploy.worker.tbo import let_another_thread_run
2426

2527
from .fused_moe_backend_base import MoEMethodBase
2628
from .fused_moe_triton_backend import BlockWiseFP8MoEMethod
@@ -142,12 +144,17 @@ def apply_ep_prefill(
142144
Apply the EP prefill method.
143145
"""
144146
gate_out = gate(x.cast("float32"))
147+
145148
# 1. Select topk experts and weights
146149
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
147150
# 2. Dynamic compute blockwise quantization scales
148151
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
149152
x, self.quant_config.weight_block_size[0]
150153
)
154+
155+
event = deep_ep.Buffer.capture()
156+
let_another_thread_run()
157+
151158
# 3. EP Dispatch
152159
(
153160
recv_x,
@@ -157,8 +164,9 @@ def apply_ep_prefill(
157164
handle,
158165
event,
159166
) = self.ep_prefill_runner.dispatch(
160-
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128
167+
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
161168
)
169+
162170
if self.ep_prefill_runner.ep_engine.async_finish:
163171
event.current_stream_wait()
164172

@@ -241,7 +249,10 @@ def apply_ep_prefill(
241249
tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16)
242250

243251
# 5. EP combine
244-
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
252+
event = deep_ep.Buffer.capture()
253+
let_another_thread_run()
254+
255+
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event)
245256

246257
if self.ep_prefill_runner.ep_engine.async_finish:
247258
event.current_stream_wait()

fastdeploy/worker/tbo.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import threading
18+
19+
from fastdeploy.model_executor.forward_meta import ForwardMeta
20+
21+
# One event per worker thread; they are used to hand execution back and forth
# between the two threads registered in GLOBAL_THREAD_INFO.
event0 = threading.Event()
event1 = threading.Event()

# Maps a thread name to [own_event, peer_event]: let_another_thread_run() sets
# the peer event (index 1) and then waits on the own event (index 0).
GLOBAL_THREAD_INFO = {
    "thread0": [event0, event1],
    "thread1": [event1, event0],
}

# Indexed by micro-batch index in split_batch_decoder_layers(); every entry is a
# mapping of attribute name -> value that is copied onto that micro-batch's
# ForwardMeta. It starts empty and is not populated anywhere in this file.
GLOBAL_ATTN_BUFFERS = {}
32+
33+
34+
def let_another_thread_run():
    """
    Hand execution over to the peer thread and block until it hands back.

    The calling thread is looked up by name in ``GLOBAL_THREAD_INFO``. For a
    registered thread, the peer's event is set, then the caller waits on its
    own event and clears it once woken. A thread whose name is not registered
    returns immediately without touching any event.
    """
    current_name = threading.current_thread().name
    if current_name not in GLOBAL_THREAD_INFO:
        return

    events = GLOBAL_THREAD_INFO[current_name]
    own_event = events[0]
    peer_event = events[1]

    # Wake the peer first, then park this thread until the peer yields back.
    peer_event.set()
    own_event.wait()
    # Reset so the next call blocks again instead of falling straight through.
    own_event.clear()
42+
43+
def split_batch_decoder_layers(forward_meta: ForwardMeta):
    """
    Split the ForwardMeta of one batch into two micro-batch ForwardMeta objects.

    Requests are divided into ``split_num`` (2) contiguous groups of
    ``ceil(real_bs / split_num)`` requests. Per-request tensors are sliced by
    request index, per-token tensors by the token boundaries read from
    ``forward_meta.cu_seqlens_q``.

    Args:
        forward_meta: Metadata of the full batch.

    Returns:
        A list with ``split_num`` entries. An entry is a new ForwardMeta for
        its group, or the original ``forward_meta`` object itself when the
        batch is not split (fewer than ``split_num`` requests, or no tokens) or
        when that group contains no requests / no tokens.

    Raises:
        AssertionError: If ``attn_mask_offsets`` holds neither one nor two
            entries per token.
        KeyError: If ``GLOBAL_ATTN_BUFFERS`` has no entry for a micro-batch
            index that is being built.
    """
    split_num = 2
    real_bs = forward_meta.seq_lens_this_time.shape[0]

    # Every slot starts out as the unsplit meta; slots are overwritten below only
    # for groups that actually contain tokens.
    # NOTE(review): a skipped group therefore keeps the full-batch meta in its
    # slot - confirm the caller handles that case.
    res = [forward_meta] * split_num

    if real_bs < split_num or forward_meta.ids_remove_padding.shape[0] == 0:
        return res

    token_num = forward_meta.ids_remove_padding.shape[0]

    # Ceil division so that the first group takes the extra request when real_bs is odd.
    mc_bs = (real_bs + split_num - 1) // split_num

    for i in range(split_num):
        start_bs = i * mc_bs
        end_bs = min(start_bs + mc_bs, real_bs)

        if start_bs >= end_bs:
            continue

        # Token range covered by requests [start_bs, end_bs).
        start_token_id = forward_meta.cu_seqlens_q[start_bs].item()
        end_token_id = forward_meta.cu_seqlens_q[end_bs].item()

        if start_token_id >= end_token_id:
            continue

        res[i] = ForwardMeta(
            ids_remove_padding=None,
            rotary_embs=forward_meta.rotary_embs,
            attn_backend=forward_meta.attn_backend,
            caches=forward_meta.caches,
        )

        res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]

        res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id]
        # Rebase batch ids so they are relative to the first request of this group.
        res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs

        res[i].seq_lens_encoder = forward_meta.seq_lens_encoder[start_bs:end_bs]
        res[i].seq_lens_decoder = forward_meta.seq_lens_decoder[start_bs:end_bs]
        res[i].seq_lens_this_time = forward_meta.seq_lens_this_time[start_bs:end_bs]

        res[i].block_tables = forward_meta.block_tables[start_bs:end_bs]

        # Cumulative offsets keep one extra trailing element and are rebased to start at 0.
        res[i].cu_seqlens_q = forward_meta.cu_seqlens_q[start_bs : end_bs + 1] - start_token_id
        # NOTE(review): cu_seqlens_k is rebased with the *query* token offset; this
        # presumes cu_seqlens_k[start_bs] == cu_seqlens_q[start_bs] - confirm for
        # requests whose key length differs from their query length.
        res[i].cu_seqlens_k = forward_meta.cu_seqlens_k[start_bs : end_bs + 1] - start_token_id

        # Attach the buffers registered for this micro-batch index.
        for key in GLOBAL_ATTN_BUFFERS[i]:
            setattr(res[i], key, GLOBAL_ATTN_BUFFERS[i][key])

        if forward_meta.attn_mask_offsets is not None:
            mask_num = forward_meta.attn_mask_offsets.shape[0]
            if mask_num == token_num * 2:
                # Two mask entries per token.
                res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id * 2 : end_token_id * 2]
            elif mask_num == token_num:
                # One mask entry per token.
                res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id:end_token_id]
            else:
                # Raised explicitly (instead of `assert False`) so the check is
                # not stripped when Python runs with -O.
                raise AssertionError("Invalid attn_mask_offsets shape")

        # The original comment here read "This is to adapt 5"; presumably it refers
        # to a model variant whose ForwardMeta carries hidden_states - TODO confirm.
        # NOTE(review): decode_states is read whenever hidden_states exists; confirm
        # both attributes are always set together.
        if hasattr(forward_meta, "hidden_states"):
            res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id]
            res[i].decode_states = forward_meta.decode_states[start_bs:end_bs]

    return res

0 commit comments

Comments
 (0)