@@ -54,6 +54,7 @@ def _save_routing_kernel(
5454 TOP_K ,
5555 NUM_HIDDEN_LAYERS ,
5656 MAX_MODEL_LEN ,
57+ MAX_NUM_SEQS ,
5758 BLOCK_SIZE_M : tl .constexpr ,
5859 BLOCK_SIZE_K : tl .constexpr ,
5960):
@@ -63,45 +64,37 @@ def _save_routing_kernel(
6364 token_mask = token_offsets < TOKEN_NUM
6465
6566 k_offsets = tl .arange (0 , BLOCK_SIZE_K )
66-
6767 k_mask = k_offsets < TOP_K
6868
6969 topk_ids_ptrs = TOPK_IDS_PTR + token_offsets [:, None ] * TOP_K + k_offsets [None , :]
70- # [BLOCK_SIZE_M, BLOCK_SIZE_K]
71-
7270 load_mask = token_mask [:, None ] & k_mask [None , :]
73- topk_vals = tl .load (topk_ids_ptrs , mask = load_mask )
74-
75- batch_ids = tl .load (BATCH_ID_PER_TOKEN_PTR + token_offsets , mask = token_mask )
76- pad_mask = token_mask & (batch_ids != - 1 )
77- # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
78- # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
79- # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
80- # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
81- start_offsets = tl .load (CU_SEQLENS_Q_PTR + batch_ids , mask = pad_mask )
71+ topk_vals = tl .load (topk_ids_ptrs , mask = load_mask , other = - 1 )
72+
73+ batch_ids = tl .load (BATCH_ID_PER_TOKEN_PTR + token_offsets , mask = token_mask , other = - 1 )
74+
75+ batch_mask = (batch_ids >= 0 ) & (batch_ids < MAX_NUM_SEQS )
76+ pad_mask = token_mask & (batch_ids != - 1 ) & batch_mask
77+
78+ start_offsets = tl .load (CU_SEQLENS_Q_PTR + batch_ids , mask = pad_mask , other = 0 )
8279 token_relative_index = token_offsets - start_offsets
8380
84- # [BLOCK_SIZE_M]
85- len_decoder = tl .load (SEQ_LENS_DECODER_PTR + batch_ids , mask = pad_mask )
81+ len_decoder = tl .load (SEQ_LENS_DECODER_PTR + batch_ids , mask = pad_mask , other = 0 )
8682 token_seq_pos = len_decoder + token_relative_index
8783
88- STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K
89- STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K
84+ STRIDE_BUF_SEQ = tl . cast ( MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K , tl . int64 )
85+ STRIDE_BUF_TOKEN = tl . cast ( NUM_HIDDEN_LAYERS * TOP_K , tl . int64 )
9086 STRIDE_BUF_LAYER = TOP_K
9187
92- # [BLOCK_SIZE_M, BLOCK_SIZE_K]
9388 output_ptrs = (
9489 ROUTING_REPLAY_TABLE_PTR
95- + batch_ids [:, None ] * STRIDE_BUF_SEQ
96- + token_seq_pos [:, None ] * STRIDE_BUF_TOKEN
97- + LAYER_IDX * STRIDE_BUF_LAYER
90+ + tl . cast ( batch_ids [:, None ], tl . int64 ) * STRIDE_BUF_SEQ
91+ + tl . cast ( token_seq_pos [:, None ], tl . int64 ) * STRIDE_BUF_TOKEN
92+ + tl . cast ( LAYER_IDX , tl . int64 ) * STRIDE_BUF_LAYER
9893 + k_offsets [None , :]
9994 )
10095
101- pos_mask = token_seq_pos < MAX_MODEL_LEN
96+ pos_mask = ( token_seq_pos >= 0 ) & ( token_seq_pos < MAX_MODEL_LEN )
10297 pos_mask = pos_mask & pad_mask
103-
104- # [BLOCK_SIZE_M, BLOCK_SIZE_K]
10598 pos_mask = pos_mask [:, None ] & k_mask [None , :]
10699
107100 final_mask = load_mask & pos_mask
@@ -120,10 +113,10 @@ def save_routing_to_buffer(
120113 ep_size : int ,
121114 tp_group : dist .communication .group .Group ,
122115):
116+ token_num_per_rank = topk_ids .shape [0 ]
117+ if token_num_per_rank == 0 :
118+ return
123119 if tp_size > 1 and ep_size > 1 :
124- token_num_per_rank = topk_ids .shape [0 ]
125- if token_num_per_rank == 0 :
126- return
127120 topk_ids_all = paddle .zeros ([token_num_per_rank * tp_size , topk_ids .shape [1 ]], dtype = topk_ids .dtype )
128121 paddle .distributed .all_gather (topk_ids_all , topk_ids , tp_group )
129122 topk_ids = topk_ids_all [: batch_id_per_token .shape [0 ], :]
@@ -150,6 +143,7 @@ def save_routing_to_buffer(
150143 TOP_K = top_k ,
151144 NUM_HIDDEN_LAYERS = num_hidden_layers ,
152145 MAX_MODEL_LEN = max_model_len ,
146+ MAX_NUM_SEQS = max_num_seqs ,
153147 BLOCK_SIZE_M = BLOCK_SIZE_M ,
154148 BLOCK_SIZE_K = BLOCK_SIZE_K ,
155149 )
@@ -166,6 +160,7 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
166160 self .num_moe_layers = fd_config .model_config .num_hidden_layers - fd_config .model_config .moe_layer_start_index
167161 self .only_last_turn = fd_config .routing_replay_config .only_last_turn
168162 self .use_fused_put = fd_config .routing_replay_config .use_fused_put
163+ logger .info (f"[R3] Rollout Routing Replay Congfig: { fd_config .routing_replay_config } " )
169164 if fd_config .model_config .architectures [0 ] == "Glm4MoeForCausalLM" :
170165 self .moe_top_k = fd_config .model_config .num_experts_per_tok
171166 else :
@@ -186,6 +181,17 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
186181 )
187182 self ._store_wrapper .start_store_warpper ()
188183
184+ # Suspend Routing Replay
185+ self .suspend_routing_replay = False
186+ self .update_suspend_routing_replay ()
187+
188+ def update_suspend_routing_replay (self ):
189+ """Allow RL to use R3 in different training rounds"""
190+ # TODO(gongshaotian): Delete this func
191+ suspend_routing_replay = os .environ .get ("FD_SUSPEND_ROUTING_REPLAY" , "0" )
192+ self .suspend_routing_replay = bool (int (suspend_routing_replay ))
193+ logger .info (f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: { self .suspend_routing_replay } " )
194+
189195 def _init_routing_cache (self , dtype : str , total_block_num : int ):
190196 """Initialize the device buffer and host buffer."""
191197
@@ -341,6 +347,11 @@ def _put_request_to_store(
341347 seq_lens_decoder ,
342348 ):
343349 if self .tp_rank == 0 :
350+ # TODO(gongshaotian): Delete the suspend func
351+ if self .suspend_routing_replay :
352+ logger .info (f"[R3] Suspend Routing Replay is enabled, skip putting request { request_id } to store" )
353+ return
354+
344355 before_put_request_time = time .perf_counter ()
345356
346357 # Collect the routing of finished request
@@ -351,16 +362,19 @@ def _put_request_to_store(
351362
352363 if self .use_fused_put :
353364 self ._store_wrapper .submit_put_task (routing_indices = batch_buffer , rollout_id = rollout_id )
365+ # Only store the routing of last turn
366+ if self .only_last_turn :
367+ self ._store_wrapper .submit_clear_prefix_batch_task (rollout_id = rollout_id )
368+
354369 else :
355370 for layer_id in range (self .num_moe_layers ):
356371 layer_buffer = batch_buffer [layer_id ]
357372 self ._store_wrapper .submit_put_task (
358373 routing_indices = layer_buffer , rollout_id = rollout_id , layer_idx = layer_id
359374 )
360-
361- # Only store the routing of last turn
362- if self .only_last_turn :
363- self ._store_wrapper .submit_clear_prefix_batch_task (rollout_id = rollout_id )
375+ # Only store the routing of last turn
376+ if self .only_last_turn :
377+ self ._store_wrapper .submit_clear_prefix_batch_task (rollout_id = rollout_id , layer_idx = layer_id )
364378
365379 logger .info (f"[R3] Submit { request_id } time cost: { time .perf_counter () - before_put_request_time } " )
366380
@@ -481,7 +495,6 @@ def _monitor_queue_load(self):
481495 if qsize > self .queue_max_size * 0.8 :
482496 logger .warning (
483497 f"[Monitor] Queue load is HIGH: { qsize } /{ self .queue_max_size } . "
484- f"Dropped tasks so far: { self ._dropped_tasks } . "
485498 "Consider increasing max_workers or queue_max_size."
486499 )
487500 logger .debug (f"[Monitor] Queue load: { qsize } /{ self .queue_max_size } " )
@@ -523,22 +536,26 @@ def submit_clear_store_task(self) -> None:
523536 raise RuntimeError ("Queue is FULL. Dropping put task for key: clear_store. " )
524537 logger .info (f"[R3] Submit clear task, cost time: { time .perf_counter ()- start_time } s" )
525538
526- def submit_clear_prefix_batch_task (self , rollout_id ) -> None :
539+ def submit_clear_prefix_batch_task (self , rollout_id , layer_idx : int = None ) -> None :
527540 """Submit clear prefix batch task"""
528541 if not self ._sotre_process_running :
529542 raise RuntimeError ("Store not started." )
530- prefix_batch = self .get_needed_clear_ids (rollout_id )
531-
532- if prefix_batch is None :
543+ prefix_batch_id = self .get_needed_clear_ids (rollout_id )
544+ if prefix_batch_id is None :
533545 return
534546 start_time = time .perf_counter ()
535- task : StoreTask = {"task_type" : "clear_prefix_batch" , "key" : prefix_batch , "data" : None }
547+ if layer_idx is not None :
548+ rdma_rollout_key = f"{ prefix_batch_id } _{ layer_idx } "
549+ else :
550+ rdma_rollout_key = prefix_batch_id
551+
552+ task : StoreTask = {"task_type" : "clear_prefix_batch" , "key" : rdma_rollout_key , "data" : None }
536553 try :
537554 self ._task_queue .put_nowait (task )
538555 except Exception :
539556 raise RuntimeError ("Queue is FULL. Dropping put task for key: clear_store. " )
540557 logger .info (
541- f"[R3] Submit clear prefix batch task for key: { prefix_batch } , cost time: { time .perf_counter ()- start_time } s"
558+ f"[R3] Submit clear prefix batch task for key: { prefix_batch_id } , cost time: { time .perf_counter ()- start_time } s"
542559 )
543560
544561 def get_needed_clear_ids (self , roullout_id : str ) -> Optional [str ]:
@@ -615,7 +632,7 @@ def run(self):
615632 self ._task_queue .task_done ()
616633 raise RuntimeError (f"Error during processing task. { e } " )
617634
618- logger .info (f"[ Consumer Process { Process . current_process (). pid } ] Shutdown." )
635+ logger .info ("RoutingReplay Consumer Process Shutdown." )
619636
620637 def process_put_task (self , store_task : StoreTask ) -> None :
621638 try :
@@ -838,13 +855,18 @@ def __init__(self, routing_replay_config) -> None:
838855 async def put (self , routing_key : str , routing_indices : np .ndarray ) -> None :
839856 """Put the routing indices into store"""
840857 time_before_put = time .perf_counter ()
841- result = await self .p2p_client .put (routing_key , routing_indices )
858+ if len (routing_indices .shape ) == 3 :
859+ # NOTE(gongshaotian) Fused put with bytes data
860+ routing_bytes = routing_indices .tobytes ()
861+ result = await self .p2p_client .put (routing_key , routing_bytes )
862+ else :
863+ result = await self .p2p_client .put (routing_key , routing_indices )
842864 logger .info (f"[R3] The routing key { routing_key } , put cost is { time .perf_counter ()- time_before_put } s" )
843865 return result
844866
845867 async def clear_prefix_batch (self , routing_prefix_key : str ):
846868 time_before_clear = time .perf_counter ()
847- result = await self .p2p_client .delete_prefix_batch ([routing_prefix_key ])
869+ result = await self .p2p_client .delete_batch ([routing_prefix_key ])
848870 logger .info (
849871 f"[R3] The clear routing prefix key { routing_prefix_key } , cost is { time .perf_counter ()- time_before_clear } s"
850872 )
0 commit comments