Skip to content

Commit 85d9951

Browse files
authored
Update Dummy Run To Suppport Mutil-Batch Execution (PaddlePaddle#6123)
1 parent 9ee0156 commit 85d9951

1 file changed

Lines changed: 40 additions & 8 deletions

File tree

fastdeploy/worker/xpu_model_runner.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,16 +1298,34 @@ def initialize_attn_backend(self) -> None:
12981298
)
12991299
self.attn_backends.append(attn_backend)
13001300

1301-
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
1302-
"""Set dummy prefill inputs to share_inputs"""
1303-
full_length = min(num_tokens // batch_size, self.model_config.max_model_len - 10)
1304-
input_length = int(full_length - 512)
1301+
def get_input_length_list(self, num_tokens: int, batch_size: int, expected_decode_len: int):
1302+
"""
1303+
Args:
1304+
num_tokens (int): The total number of tokens across all sequences.
1305+
batch_size (int): The number of sequences (requests) in the batch.
1306+
expected_decode_len (int): The expected number of tokens every sequence should be generated by the model.
1307+
Returns:
1308+
List[int]: A list of integers representing the sequence length for each request.
1309+
This list is crafted to maximize the total number of blocks.
1310+
"""
1311+
max_dec_len = expected_decode_len + 1
1312+
input_length = min(num_tokens // batch_size, self.model_config.max_model_len - max_dec_len)
13051313
block_num = (
13061314
input_length + self.cache_config.block_size - 1
13071315
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
1316+
input_length_list = [input_length] * batch_size
1317+
len_of_input_length_list = len(input_length_list)
1318+
max_dec_len_list = [max_dec_len] * len_of_input_length_list
1319+
return input_length_list, max_dec_len_list, block_num
1320+
1321+
def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int):
1322+
"""Set dummy prefill inputs to share_inputs"""
1323+
batch_size = len(input_length_list)
13081324

13091325
for i in range(batch_size):
13101326
idx = i
1327+
input_length = input_length_list[idx]
1328+
max_dec_len = max_dec_len_list[idx]
13111329
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
13121330
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
13131331
self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
@@ -1317,7 +1335,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
13171335
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
13181336
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
13191337
self.share_inputs["step_idx"][idx : idx + 1] = 0
1320-
self.share_inputs["max_dec_len"][idx : idx + 1] = 10
1338+
self.share_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
13211339
self.share_inputs["stop_flags"][idx : idx + 1] = False
13221340

13231341
self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
@@ -1333,14 +1351,27 @@ def _dummy_run(
13331351
self,
13341352
num_tokens: paddle.Tensor,
13351353
batch_size: paddle.Tensor,
1354+
expected_decode_len: int = 1,
13361355
in_capturing: bool = False,
13371356
) -> paddle.Tensor:
13381357
"""
13391358
Use dummy inputs to run before formal execution.
13401359
Args:
1341-
num_tokens: Expected number of tokens generated
1360+
num_tokens: Number of the input tokens
1361+
batch_size: Batch size
1362+
expected_decode_len: Expected decode length
1363+
in_capturing: Is cuda graph in capturing state
13421364
"""
1343-
self._dummy_prefill_inputs(num_tokens, batch_size)
1365+
input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
1366+
num_tokens=num_tokens,
1367+
batch_size=batch_size,
1368+
expected_decode_len=expected_decode_len,
1369+
)
1370+
self._dummy_prefill_inputs(
1371+
input_length_list=input_length_list,
1372+
max_dec_len_list=max_dec_len_list,
1373+
block_num=block_num,
1374+
)
13441375

13451376
if self.speculative_method in ["mtp"]:
13461377
self.proposer.dummy_prefill_inputs(
@@ -1350,7 +1381,7 @@ def _dummy_run(
13501381
)
13511382

13521383
while True:
1353-
self.execute_model(is_dummy_run=True)
1384+
self.execute_model(is_dummy_run=True, in_capturing=in_capturing)
13541385

13551386
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
13561387
break
@@ -1417,6 +1448,7 @@ def execute_model(
14171448
model_forward_batch: Optional[List[Request]] = None,
14181449
num_running_requests: int = None,
14191450
is_dummy_run: bool = False,
1451+
in_capturing: bool = False,
14201452
) -> Optional[ModelRunnerOutput]:
14211453
"""
14221454
The Entrance of model execute.

0 commit comments

Comments
 (0)