@@ -588,7 +588,6 @@ def reset_share_inputs(self):
588588 fill_paddle_tensor (self , "step_idx" , 0 )
589589 # fill_paddle_tensor(self, "not_need_stop", False)
590590 fill_paddle_tensor (self , "not_need_stop_device" , False )
591- fill_paddle_tensor (self , "sampled_token_ids" , - 1 )
592591 fill_paddle_tensor (self , "stop_flags" , True )
593592
594593 fill_paddle_tensor (self , "bad_tokens" , - 1 )
@@ -726,6 +725,12 @@ def reset_share_inputs(self):
726725 # Reset other miscellaneous tensors
727726 fill_paddle_tensor (self , "mask_rollback" , 0 )
728727 fill_paddle_tensor (self , "preempted_idx" , 0 )
728+ fill_paddle_tensor (self , "last_preempted_idx" , 0 )
729+
730+ # Reset tensors for overlap
731+ self .sampled_token_ids = paddle .full ([max_num_seqs , 1 ], - 1 , dtype = "int64" ).pin_memory ()
732+ self .seq_lens_this_time_cpu = paddle .full ([max_num_seqs , 1 ], 0 , dtype = "int32" ).pin_memory ()
733+ self .is_block_step_cpu = paddle .full ([max_num_seqs ], False , dtype = "bool" ).pin_memory ()
729734
730735 logger .info ("share_inputs reset completed" )
731736 except Exception as e :
0 commit comments