3131 enable_radix_tree_timer_merge ,
3232 get_radix_tree_merge_update_delta ,
3333)
34- from lightllm .distributed import dist_group_manager
34+ from lightllm .distributed .communication_op import (
35+ dist_group_manager ,
36+ all_gather_into_tensor ,
37+ all_reduce ,
38+ broadcast ,
39+ )
3540from lightllm .server .core .objs .shm_objs_io_buffer import ShmObjsIOBuffer
3641from lightllm .server .router .model_infer .mode_backend .overlap_events import OverlapEventManager , OverlapEventPack
3742from lightllm .models .deepseek_mtp .model import Deepseek3MTPModel
@@ -368,7 +373,7 @@ def _try_read_new_reqs_normal(self):
368373 self .node_broadcast_tensor .fill_ (0 )
369374
370375 src_rank_id = self .args .node_rank * self .node_world_size
371- dist . broadcast (self .node_broadcast_tensor , src = src_rank_id , group = self .node_nccl_group , async_op = False )
376+ broadcast (self .node_broadcast_tensor , src = src_rank_id , group = self .node_nccl_group , async_op = False )
372377 new_buffer_is_ready = self .node_broadcast_tensor .detach ().item ()
373378 if new_buffer_is_ready :
374379 self ._read_reqs_buffer_and_init_reqs ()
@@ -382,7 +387,7 @@ def _try_read_new_reqs_normal(self):
382387 self .node_broadcast_tensor .fill_ (0 )
383388
384389 src_rank_id = self .args .node_rank * self .node_world_size
385- dist . broadcast (self .node_broadcast_tensor , src = src_rank_id , group = self .node_nccl_group , async_op = False )
390+ broadcast (self .node_broadcast_tensor , src = src_rank_id , group = self .node_nccl_group , async_op = False )
386391 new_buffer_is_ready = self .node_broadcast_tensor .detach ().item ()
387392 if new_buffer_is_ready :
388393 self ._read_nixl_trans_io_buffer_and_update_req_status ()
@@ -396,7 +401,7 @@ def _try_read_new_reqs_multinode_tp(self):
396401 self .multinode_tp_gather_item_tensor .fill_ (1 )
397402 else :
398403 self .multinode_tp_gather_item_tensor .fill_ (0 )
399- dist . all_gather_into_tensor (
404+ all_gather_into_tensor (
400405 self .multinode_tp_all_gather_tensor ,
401406 self .multinode_tp_gather_item_tensor ,
402407 group = self .multinode_tp_nccl_group ,
@@ -806,12 +811,12 @@ def _dp_all_gather_prefill_and_decode_req_num(
806811 """
807812 current_dp_prefill_num = len (prefill_reqs )
808813 self .dp_gather_item_tensor .fill_ (current_dp_prefill_num )
809- dist . all_gather_into_tensor (self .dp_all_gather_tensor , self .dp_gather_item_tensor , group = None , async_op = False )
814+ all_gather_into_tensor (self .dp_all_gather_tensor , self .dp_gather_item_tensor , group = None , async_op = False )
810815 dp_prefill_req_nums = self .dp_all_gather_tensor .cpu ().numpy ()
811816
812817 current_dp_decode_num = len (decode_reqs )
813818 self .dp_gather_item_tensor .fill_ (current_dp_decode_num )
814- dist . all_gather_into_tensor (self .dp_all_gather_tensor , self .dp_gather_item_tensor , group = None , async_op = False )
819+ all_gather_into_tensor (self .dp_all_gather_tensor , self .dp_gather_item_tensor , group = None , async_op = False )
815820 dp_decode_req_nums = self .dp_all_gather_tensor .cpu ().numpy ()
816821
817822 return dp_prefill_req_nums , dp_decode_req_nums
@@ -822,7 +827,7 @@ def _dp_all_reduce_decode_req_num(self, decode_reqs: List[InferReq]) -> int:
822827 """
823828 current_dp_decode_num = len (decode_reqs )
824829 self .dp_reduce_tensor .fill_ (current_dp_decode_num )
825- dist . all_reduce (self .dp_reduce_tensor , op = dist .ReduceOp .MAX , group = None , async_op = False )
830+ all_reduce (self .dp_reduce_tensor , op = dist .ReduceOp .MAX , group = None , async_op = False )
826831 max_decode_num = self .dp_reduce_tensor .item ()
827832 return max_decode_num
828833
0 commit comments