Skip to content

Commit d8cdda8

Browse files
[RL][Cherry-Pick] Fix the out-of-bounds issue caused by int32 in the R3 kernel (PaddlePaddle#7155)
* [RL]Perf: Optimize batch delete prefix and fused put in R3 (PaddlePaddle#6604) * Optimize delete batch and fused put * refine code * refine code * refine code * Support suspend r3 * [RL] Fix R3 Empty bug with TP=1 (PaddlePaddle#6777) * Fix int32 overflow * refine code --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
1 parent 3c7ca62 commit d8cdda8

4 files changed

Lines changed: 75 additions & 41 deletions

File tree

fastdeploy/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,6 +1764,9 @@ def __init__(self, args: dict):
17641764
else:
17651765
self.metrics_port = self.api_server_port
17661766

1767+
def __str__(self):
1768+
return json.dumps({key: value for key, value in self.__dict__.items()})
1769+
17671770

17681771
class CommitConfig:
17691772
"""
@@ -1877,6 +1880,9 @@ def to_json_string(self):
18771880
"""
18781881
return json.dumps({key: value for key, value in self.__dict__.items()})
18791882

1883+
def __str__(self):
1884+
return self.to_json_string()
1885+
18801886

18811887
class FDConfig:
18821888
"""

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ def _validate_split_kv_size(value: int) -> int:
253253
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
254254
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
255255
),
256+
# Suspend Rollout Routing Replay (R3): set to 1 to skip putting routing results to the store
257+
"FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))),
256258
# train-infer consistency, used in RL
257259
# Whether to align RoPE and moe gate precision with training
258260
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),

fastdeploy/model_executor/layers/moe/routing_indices_cache.py

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def _save_routing_kernel(
5454
TOP_K,
5555
NUM_HIDDEN_LAYERS,
5656
MAX_MODEL_LEN,
57+
MAX_NUM_SEQS,
5758
BLOCK_SIZE_M: tl.constexpr,
5859
BLOCK_SIZE_K: tl.constexpr,
5960
):
@@ -63,45 +64,37 @@ def _save_routing_kernel(
6364
token_mask = token_offsets < TOKEN_NUM
6465

6566
k_offsets = tl.arange(0, BLOCK_SIZE_K)
66-
6767
k_mask = k_offsets < TOP_K
6868

6969
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
70-
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
71-
7270
load_mask = token_mask[:, None] & k_mask[None, :]
73-
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask)
74-
75-
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
76-
pad_mask = token_mask & (batch_ids != -1)
77-
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
78-
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
79-
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
80-
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
81-
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
71+
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1)
72+
73+
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1)
74+
75+
batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS)
76+
pad_mask = token_mask & (batch_ids != -1) & batch_mask
77+
78+
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0)
8279
token_relative_index = token_offsets - start_offsets
8380

84-
# [BLOCK_SIZE_M]
85-
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
81+
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0)
8682
token_seq_pos = len_decoder + token_relative_index
8783

88-
STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K
89-
STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K
84+
STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
85+
STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
9086
STRIDE_BUF_LAYER = TOP_K
9187

92-
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
9388
output_ptrs = (
9489
ROUTING_REPLAY_TABLE_PTR
95-
+ batch_ids[:, None] * STRIDE_BUF_SEQ
96-
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
97-
+ LAYER_IDX * STRIDE_BUF_LAYER
90+
+ tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ
91+
+ tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN
92+
+ tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER
9893
+ k_offsets[None, :]
9994
)
10095

101-
pos_mask = token_seq_pos < MAX_MODEL_LEN
96+
pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN)
10297
pos_mask = pos_mask & pad_mask
103-
104-
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
10598
pos_mask = pos_mask[:, None] & k_mask[None, :]
10699

107100
final_mask = load_mask & pos_mask
@@ -120,10 +113,10 @@ def save_routing_to_buffer(
120113
ep_size: int,
121114
tp_group: dist.communication.group.Group,
122115
):
116+
token_num_per_rank = topk_ids.shape[0]
117+
if token_num_per_rank == 0:
118+
return
123119
if tp_size > 1 and ep_size > 1:
124-
token_num_per_rank = topk_ids.shape[0]
125-
if token_num_per_rank == 0:
126-
return
127120
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
128121
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
129122
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
@@ -150,6 +143,7 @@ def save_routing_to_buffer(
150143
TOP_K=top_k,
151144
NUM_HIDDEN_LAYERS=num_hidden_layers,
152145
MAX_MODEL_LEN=max_model_len,
146+
MAX_NUM_SEQS=max_num_seqs,
153147
BLOCK_SIZE_M=BLOCK_SIZE_M,
154148
BLOCK_SIZE_K=BLOCK_SIZE_K,
155149
)
@@ -166,6 +160,7 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
166160
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
167161
self.only_last_turn = fd_config.routing_replay_config.only_last_turn
168162
self.use_fused_put = fd_config.routing_replay_config.use_fused_put
163+
logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}")
169164
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
170165
self.moe_top_k = fd_config.model_config.num_experts_per_tok
171166
else:
@@ -186,6 +181,17 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
186181
)
187182
self._store_wrapper.start_store_warpper()
188183

184+
# Suspend Routing Replay
185+
self.suspend_routing_replay = False
186+
self.update_suspend_routing_replay()
187+
188+
def update_suspend_routing_replay(self):
189+
"""Allow RL to use R3 in different training rounds"""
190+
# TODO(gongshaotian): Delete this func
191+
suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0")
192+
self.suspend_routing_replay = bool(int(suspend_routing_replay))
193+
logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}")
194+
189195
def _init_routing_cache(self, dtype: str, total_block_num: int):
190196
"""Initialize the device buffer and host buffer."""
191197

@@ -341,6 +347,11 @@ def _put_request_to_store(
341347
seq_lens_decoder,
342348
):
343349
if self.tp_rank == 0:
350+
# TODO(gongshaotian): Delete the suspend func
351+
if self.suspend_routing_replay:
352+
logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store")
353+
return
354+
344355
before_put_request_time = time.perf_counter()
345356

346357
# Collect the routing of finished request
@@ -351,16 +362,19 @@ def _put_request_to_store(
351362

352363
if self.use_fused_put:
353364
self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id)
365+
# Only store the routing of last turn
366+
if self.only_last_turn:
367+
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id)
368+
354369
else:
355370
for layer_id in range(self.num_moe_layers):
356371
layer_buffer = batch_buffer[layer_id]
357372
self._store_wrapper.submit_put_task(
358373
routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id
359374
)
360-
361-
# Only store the routing of last turn
362-
if self.only_last_turn:
363-
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id)
375+
# Only store the routing of last turn
376+
if self.only_last_turn:
377+
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id)
364378

365379
logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}")
366380

@@ -481,7 +495,6 @@ def _monitor_queue_load(self):
481495
if qsize > self.queue_max_size * 0.8:
482496
logger.warning(
483497
f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. "
484-
f"Dropped tasks so far: {self._dropped_tasks}. "
485498
"Consider increasing max_workers or queue_max_size."
486499
)
487500
logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}")
@@ -523,22 +536,26 @@ def submit_clear_store_task(self) -> None:
523536
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
524537
logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s")
525538

526-
def submit_clear_prefix_batch_task(self, rollout_id) -> None:
539+
def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None:
527540
"""Submit clear prefix batch task"""
528541
if not self._sotre_process_running:
529542
raise RuntimeError("Store not started.")
530-
prefix_batch = self.get_needed_clear_ids(rollout_id)
531-
532-
if prefix_batch is None:
543+
prefix_batch_id = self.get_needed_clear_ids(rollout_id)
544+
if prefix_batch_id is None:
533545
return
534546
start_time = time.perf_counter()
535-
task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None}
547+
if layer_idx is not None:
548+
rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}"
549+
else:
550+
rdma_rollout_key = prefix_batch_id
551+
552+
task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None}
536553
try:
537554
self._task_queue.put_nowait(task)
538555
except Exception:
539556
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
540557
logger.info(
541-
f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s"
558+
f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s"
542559
)
543560

544561
def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]:
@@ -615,7 +632,7 @@ def run(self):
615632
self._task_queue.task_done()
616633
raise RuntimeError(f"Error during processing task. {e}")
617634

618-
logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.")
635+
logger.info("RoutingReplay Consumer Process Shutdown.")
619636

620637
def process_put_task(self, store_task: StoreTask) -> None:
621638
try:
@@ -838,13 +855,18 @@ def __init__(self, routing_replay_config) -> None:
838855
async def put(self, routing_key: str, routing_indices: np.ndarray) -> None:
839856
"""Put the routing indices into store"""
840857
time_before_put = time.perf_counter()
841-
result = await self.p2p_client.put(routing_key, routing_indices)
858+
if len(routing_indices.shape) == 3:
859+
# NOTE(gongshaotian) Fused put with bytes data
860+
routing_bytes = routing_indices.tobytes()
861+
result = await self.p2p_client.put(routing_key, routing_bytes)
862+
else:
863+
result = await self.p2p_client.put(routing_key, routing_indices)
842864
logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s")
843865
return result
844866

845867
async def clear_prefix_batch(self, routing_prefix_key: str):
846868
time_before_clear = time.perf_counter()
847-
result = await self.p2p_client.delete_prefix_batch([routing_prefix_key])
869+
result = await self.p2p_client.delete_batch([routing_prefix_key])
848870
logger.info(
849871
f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s"
850872
)

fastdeploy/worker/gpu_model_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2857,9 +2857,13 @@ def update_parameters(self, pid):
28572857
# Recapture CUDAGraph
28582858
if self.use_cudagraph:
28592859
self.capture_model()
2860+
# Rollout Routing Replay
2861+
if self.fd_config.routing_replay_config.enable_routing_replay:
2862+
# TODO(gongshaotian): Delete suspend func
2863+
self.routing_replay_manager.update_suspend_routing_replay()
2864+
28602865
# Send single
28612866
self.dynamic_weight_manager.finalize_update(pid)
2862-
28632867
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
28642868

28652869
def update_weights(self, version: str = None, verify_checksum: bool = False):

0 commit comments

Comments
 (0)