Skip to content

Commit 261041b

Browse files
authored
[Cherry-Pick][Bugfix] Fix clear bug in RL causing CUDA error 700 during CUDAGraph recapture (#7934) (#7933)
* Fix clear bug in RL.
* fix: use self.max_chunk_tokens instead of fd_config.get_max_chunk_tokens() for buffer recreation. Calling fd_config.get_max_chunk_tokens() without the mm_max_tokens_per_item argument may return a value smaller than the actual initial buffer size when enable_mm is set and mm_max_tokens_per_item is None. Use self.max_chunk_tokens, which is already computed during __init__ and is consistent with the first CUDAGraph capture.
1 parent c52b063 commit 261041b

1 file changed

Lines changed: 34 additions & 8 deletions

File tree

fastdeploy/worker/input_batch.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -936,8 +936,12 @@ def reset_model_inputs(self) -> None:
936936
self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
937937
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
938938
fill_paddle_tensor(self, "input_ids_cpu", -1)
939-
# acceptance rate decline when reset seq_lens_this_time
940-
# self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])
939+
# NOTE(fix): Must reset seq_lens_this_time_buffer to avoid stale values from previous
940+
# RL round causing illegal memory access during CUDAGraph recapture (error 700).
941+
# When draft_model_use_cudagraph=true, padding_cudagraph_inputs() uses the full
942+
# seq_lens_this_time_buffer tensor; residual non-zero values in high-index slots
943+
# (from previous round) will make attention kernel access invalid block_table entries.
944+
fill_paddle_tensor(self, "seq_lens_this_time_buffer", 0)
941945

942946
self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"])
943947
self.seq_lens_decoder = paddle.clone(self.target_model_input_batch["seq_lens_decoder"])
@@ -946,8 +950,19 @@ def reset_model_inputs(self) -> None:
946950
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
947951
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
948952
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
953+
self.not_need_stop_device = paddle.to_tensor([False], dtype="bool")
949954
self.index_to_batch_id = {}
950955
if current_platform.is_cuda():
956+
# NOTE(fix): These tensors get reshaped during runtime inference, so we must
957+
# recreate them at full initial size instead of cloning the (possibly resized)
958+
# target_model_input_batch tensors. Otherwise CUDAGraph replay will write
959+
# beyond tensor boundaries causing CUDA error(700).
960+
max_num_seqs = self.scheduler_config.max_num_seqs
961+
max_draft_token_num = self.speculative_config.num_speculative_tokens
962+
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
963+
self.batch_id_per_token_output = paddle.full(
964+
shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32"
965+
)
951966
if "token_ids_all" in self.target_model_input_batch:
952967
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
953968
# TODO: delete pre_ids in mtp
@@ -967,13 +982,24 @@ def reset_model_inputs(self) -> None:
967982
self.token_ids_all = None
968983
else:
969984
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
970-
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
971-
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
972-
self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"])
973-
self.cu_seqlens_k = paddle.clone(self.target_model_input_batch["cu_seqlens_k"])
974985

975-
# Reset target hidden states
976-
fill_paddle_tensor(self, "target_hidden_states", 0)
986+
# NOTE(fix): These tensors are dynamically resized during runtime inference.
987+
# Must recreate at full initial size to avoid CUDAGraph replay OOB access.
988+
max_num_seqs = self.scheduler_config.max_num_seqs
989+
self.ids_remove_padding = paddle.full([max_num_seqs * self.max_chunk_tokens], 0, dtype="int64")
990+
self.batch_id_per_token = paddle.full([max_num_seqs * self.max_chunk_tokens, 1], 0, dtype="int32")
991+
self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32")
992+
self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32")
993+
994+
# Reset target hidden states - must recreate at full size
995+
self.target_hidden_states = paddle.full(
996+
[
997+
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens,
998+
self.model_config.hidden_size,
999+
],
1000+
0,
1001+
dtype="bfloat16",
1002+
)
9771003

9781004
# Reset rope embedding by recreating with default position_ids
9791005
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))

0 commit comments

Comments
 (0)