@@ -1298,16 +1298,34 @@ def initialize_attn_backend(self) -> None:
12981298 )
12991299 self .attn_backends .append (attn_backend )
13001300
1301- def _dummy_prefill_inputs (self , num_tokens : int , batch_size : int ):
1302- """Set dummy prefill inputs to share_inputs"""
1303- full_length = min (num_tokens // batch_size , self .model_config .max_model_len - 10 )
1304- input_length = int (full_length - 512 )
1301+ def get_input_length_list (self , num_tokens : int , batch_size : int , expected_decode_len : int ):
1302+ """
1303+ Args:
1304+ num_tokens (int): The total number of tokens across all sequences.
1305+ batch_size (int): The number of sequences (requests) in the batch.
1306+ expected_decode_len (int): The expected number of tokens every sequence should be generated by the model.
1307+ Returns:
1308+ List[int]: A list of integers representing the sequence length for each request.
1309+ This list is crafted to maximize the total number of blocks.
1310+ """
1311+ max_dec_len = expected_decode_len + 1
1312+ input_length = min (num_tokens // batch_size , self .model_config .max_model_len - max_dec_len )
13051313 block_num = (
13061314 input_length + self .cache_config .block_size - 1
13071315 ) // self .cache_config .block_size + self .cache_config .enc_dec_block_num
1316+ input_length_list = [input_length ] * batch_size
1317+ len_of_input_length_list = len (input_length_list )
1318+ max_dec_len_list = [max_dec_len ] * len_of_input_length_list
1319+ return input_length_list , max_dec_len_list , block_num
1320+
1321+ def _dummy_prefill_inputs (self , input_length_list : List [int ], max_dec_len_list : List [int ], block_num : int ):
1322+ """Set dummy prefill inputs to share_inputs"""
1323+ batch_size = len (input_length_list )
13081324
13091325 for i in range (batch_size ):
13101326 idx = i
1327+ input_length = input_length_list [idx ]
1328+ max_dec_len = max_dec_len_list [idx ]
13111329 self .share_inputs ["input_ids" ][idx : idx + 1 , :input_length ] = np .array ([5 ] * input_length )
13121330 self .share_inputs ["prompt_ids" ][idx : idx + 1 , :input_length ] = np .array ([5 ] * input_length )
13131331 self .share_inputs ["eos_token_id" ][:] = np .array ([2 ], dtype = "int64" ).reshape (- 1 , 1 )
@@ -1317,7 +1335,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
13171335 self .share_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = input_length
13181336 self .share_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = 0
13191337 self .share_inputs ["step_idx" ][idx : idx + 1 ] = 0
1320- self .share_inputs ["max_dec_len" ][idx : idx + 1 ] = 10
1338+ self .share_inputs ["max_dec_len" ][idx : idx + 1 ] = max_dec_len
13211339 self .share_inputs ["stop_flags" ][idx : idx + 1 ] = False
13221340
13231341 self .share_inputs ["first_token_ids" ][idx : idx + 1 ] = self .share_inputs ["input_ids" ][idx : idx + 1 , :1 ]
@@ -1333,14 +1351,27 @@ def _dummy_run(
13331351 self ,
13341352 num_tokens : paddle .Tensor ,
13351353 batch_size : paddle .Tensor ,
1354+ expected_decode_len : int = 1 ,
13361355 in_capturing : bool = False ,
13371356 ) -> paddle .Tensor :
13381357 """
13391358 Use dummy inputs to run before formal execution.
13401359 Args:
1341- num_tokens: Expected number of tokens generated
1360+ num_tokens: Number of the input tokens
1361+ batch_size: Batch size
1362+ expected_decode_len: Expected decode length
1363+ in_capturing: Is cuda graph in capturing state
13421364 """
1343- self ._dummy_prefill_inputs (num_tokens , batch_size )
1365+ input_length_list , max_dec_len_list , block_num = self .get_input_length_list (
1366+ num_tokens = num_tokens ,
1367+ batch_size = batch_size ,
1368+ expected_decode_len = expected_decode_len ,
1369+ )
1370+ self ._dummy_prefill_inputs (
1371+ input_length_list = input_length_list ,
1372+ max_dec_len_list = max_dec_len_list ,
1373+ block_num = block_num ,
1374+ )
13441375
13451376 if self .speculative_method in ["mtp" ]:
13461377 self .proposer .dummy_prefill_inputs (
@@ -1350,7 +1381,7 @@ def _dummy_run(
13501381 )
13511382
13521383 while True :
1353- self .execute_model (is_dummy_run = True )
1384+ self .execute_model (is_dummy_run = True , in_capturing = in_capturing )
13541385
13551386 if int ((self .share_inputs ["seq_lens_this_time" ] > 0 ).sum ()) == 0 :
13561387 break
@@ -1417,6 +1448,7 @@ def execute_model(
14171448 model_forward_batch : Optional [List [Request ]] = None ,
14181449 num_running_requests : int = None ,
14191450 is_dummy_run : bool = False ,
1451+ in_capturing : bool = False ,
14201452 ) -> Optional [ModelRunnerOutput ]:
14211453 """
14221454 The Entrance of model execute.
0 commit comments