@@ -936,8 +936,12 @@ def reset_model_inputs(self) -> None:
936936 self .block_tables = paddle .clone (self .target_model_input_batch ["block_tables" ])
937937 self .input_ids = paddle .clone (self .target_model_input_batch ["input_ids" ])
938938 fill_paddle_tensor (self , "input_ids_cpu" , - 1 )
939- # acceptance rate decline when reset seq_lens_this_time
940- # self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])
939+ # NOTE(fix): Must reset seq_lens_this_time_buffer to avoid stale values from previous
940+ # RL round causing illegal memory access during CUDAGraph recapture (error 700).
941+ # When draft_model_use_cudagraph=true, padding_cudagraph_inputs() uses the full
942+ # seq_lens_this_time_buffer tensor; residual non-zero values in high-index slots
943+ # (from previous round) will make attention kernel access invalid block_table entries.
944+ fill_paddle_tensor (self , "seq_lens_this_time_buffer" , 0 )
941945
942946 self .seq_lens_encoder = paddle .clone (self .target_model_input_batch ["seq_lens_encoder" ])
943947 self .seq_lens_decoder = paddle .clone (self .target_model_input_batch ["seq_lens_decoder" ])
@@ -946,8 +950,19 @@ def reset_model_inputs(self) -> None:
946950 self .step_idx = paddle .clone (self .target_model_input_batch ["step_idx" ])
947951 self .stop_flags = paddle .clone (self .target_model_input_batch ["stop_flags" ])
948952 self .not_need_stop = paddle .to_tensor ([False ], dtype = "bool" , place = "cpu" )
953+ self .not_need_stop_device = paddle .to_tensor ([False ], dtype = "bool" )
949954 self .index_to_batch_id = {}
950955 if current_platform .is_cuda ():
956+ # NOTE(fix): These tensors get reshaped during runtime inference, so we must
957+ # recreate them at full initial size instead of cloning the (possibly resized)
958+ # target_model_input_batch tensors. Otherwise CUDAGraph replay will write
959+ # beyond tensor boundaries causing CUDA error(700).
960+ max_num_seqs = self .scheduler_config .max_num_seqs
961+ max_draft_token_num = self .speculative_config .num_speculative_tokens
962+ self .cu_seqlens_q_output = paddle .full (shape = [max_num_seqs + 1 , 1 ], fill_value = 0 , dtype = "int32" )
963+ self .batch_id_per_token_output = paddle .full (
964+ shape = [max_num_seqs * (max_draft_token_num + 1 )], fill_value = 0 , dtype = "int32"
965+ )
951966 if "token_ids_all" in self .target_model_input_batch :
952967 self .token_ids_all = paddle .clone (self .target_model_input_batch ["token_ids_all" ])
953968 # TODO: delete pre_ids in mtp
@@ -967,13 +982,24 @@ def reset_model_inputs(self) -> None:
967982 self .token_ids_all = None
968983 else :
969984 self .pre_ids = paddle .clone (self .target_model_input_batch ["pre_ids" ])
970- self .ids_remove_padding = paddle .clone (self .target_model_input_batch ["ids_remove_padding" ])
971- self .batch_id_per_token = paddle .clone (self .target_model_input_batch ["batch_id_per_token" ])
972- self .cu_seqlens_q = paddle .clone (self .target_model_input_batch ["cu_seqlens_q" ])
973- self .cu_seqlens_k = paddle .clone (self .target_model_input_batch ["cu_seqlens_k" ])
974985
975- # Reset target hidden states
976- fill_paddle_tensor (self , "target_hidden_states" , 0 )
986+ # NOTE(fix): These tensors are dynamically resized during runtime inference.
987+ # Must recreate at full initial size to avoid CUDAGraph replay OOB access.
988+ max_num_seqs = self .scheduler_config .max_num_seqs
989+ self .ids_remove_padding = paddle .full ([max_num_seqs * self .max_chunk_tokens ], 0 , dtype = "int64" )
990+ self .batch_id_per_token = paddle .full ([max_num_seqs * self .max_chunk_tokens , 1 ], 0 , dtype = "int32" )
991+ self .cu_seqlens_q = paddle .full ([max_num_seqs + 1 ], 0 , dtype = "int32" )
992+ self .cu_seqlens_k = paddle .full ([max_num_seqs + 1 ], 0 , dtype = "int32" )
993+
994+ # Reset target hidden states - must recreate at full size
995+ self .target_hidden_states = paddle .full (
996+ [
997+ self .scheduler_config .max_num_batched_tokens + self .scheduler_config .max_extra_num_batched_tokens ,
998+ self .model_config .hidden_size ,
999+ ],
1000+ 0 ,
1001+ dtype = "bfloat16" ,
1002+ )
9771003
9781004 # Reset rope embedding by recreating with default position_ids
9791005 tmp_position_ids = paddle .arange (self .model_config .max_model_len ).reshape ((1 , - 1 ))
0 commit comments