 
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.common.basemodel.routing_manager import (
-    create_routing_capture_manager,
-    reset_moe_layer_counter,
-    get_moe_layer_count,
-)
+from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter
 from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
 from lightllm.common.req_manager import ReqManager
@@ -282,45 +278,16 @@ def _init_prefill_cuda_graph(self):
         self.prefill_graph.warmup(self)
 
     def _init_custom(self):
-        if self.args.enable_return_routed_experts:
-            # Get MoE layer count from counter (set during _init_weights)
-            num_moe_layers = get_moe_layer_count()
-            if num_moe_layers == 0:
-                logger.warning(
-                    "enable_return_routed_experts is set but no MoE layers found. "
-                    "Routing capture will not be enabled."
-                )
-                return
-
-            # Get MoE parameters from model config
-            n_routed_experts = self.config.get("n_routed_experts", self.config.get("num_experts", 0))
-            if n_routed_experts == 0:
-                logger.warning(
-                    "enable_return_routed_experts is set but n_routed_experts=0. "
-                    "Routing capture will not be enabled."
-                )
-                return
+        """Hook for model-specific initialization. Override in subclasses."""
+        pass
 
-            topk = self.config.get("num_experts_per_tok", 1)
-            num_experts = n_routed_experts
+    def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None:
+        """Hook called after forward pass completes. Override in subclasses for post-processing."""
+        pass
 
-            # Check if overlap mode is enabled
-            enable_overlap = getattr(self.args, "enable_decode_microbatch_overlap", False)
-
-            logger.info(
-                f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
-                f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
-            )
-
-            create_routing_capture_manager(
-                num_moe_layers=num_moe_layers,
-                topk=topk,
-                num_experts=num_experts,
-                batch_max_tokens=self.max_total_token_num,
-                kv_cache_size=self.mem_manager.size,
-                enable_overlap=enable_overlap,
-            )
-        return
+    def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None:
+        """Hook called after dual microbatch forward pass completes. Override in subclasses."""
+        pass
 
     @torch.no_grad()
     def forward(self, model_input: ModelInput):
@@ -332,7 +299,7 @@ def forward(self, model_input: ModelInput):
         else:
             result = self._decode(model_input)
 
-        # Note: flush is now handled by backend layer (ChunkedPrefill, DP, etc.)
+        self._post_forward(model_input)
         return result
 
     def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
@@ -726,6 +693,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         dist_group_manager.clear_deepep_buffer()
         model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        self._post_forward_dual(model_input0, model_input1)
         return model_output0, model_output1
 
     @torch.no_grad()
@@ -819,6 +787,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
         infer_state1.init_att_state()
 
         model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1)
+        self._post_forward_dual(model_input0, model_input1)
         return model_output0, model_output1
 
     @final
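
For reference, a minimal sketch (not part of this diff) of how a model subclass might use the new hooks. The subclass name is hypothetical, and it assumes the base class touched above is TpPartBaseModel and that ModelInput is importable from lightllm.common.basemodel.batch_objs; only the hook signatures themselves are taken from the diff.

# Hypothetical subclass wiring model-specific setup and post-processing
# through the new _init_custom / _post_forward / _post_forward_dual hooks.
from lightllm.common.basemodel import TpPartBaseModel
from lightllm.common.basemodel.batch_objs import ModelInput


class MyMoEModel(TpPartBaseModel):
    def _init_custom(self):
        # Runs once during model init; a natural place for per-model setup,
        # e.g. allocating buffers for routing capture.
        self._captured_routing = []

    def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None:
        # Runs after each single-batch forward(); flush or record anything
        # produced during the pass for the tokens in `model_input`.
        pass

    def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None:
        # Runs after microbatch_overlap_prefill/decode finish both microbatches.
        self._post_forward(model_input0, microbatch_index=0)
        self._post_forward(model_input1, microbatch_index=1)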