
Commit 57006d4

blueswhen authored
fix: fix a memleak (#1206)
Co-authored-by: niushengxiao <niushengxiao@sensetime.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
1 parent bc07fe1 commit 57006d4

2 files changed: 45 additions & 7 deletions


lightllm/distributed/communication_op.py

Lines changed: 33 additions & 0 deletions
@@ -189,6 +189,8 @@ def all_reduce(
     op: ReduceOp = ReduceOp.SUM,
     async_op: bool = False,
 ) -> None:
+    if _is_single_group(group=group):
+        return
     if isinstance(group, CustomProcessGroup):
         return group.all_reduce(input_)
     else:
@@ -201,6 +203,9 @@ def all_gather_into_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
+    if _is_single_group(group=group):
+        output_.copy_(input_)
+        return
     if isinstance(group, CustomProcessGroup):
         return group.all_gather_into_tensor(output_, input_)
     else:
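
The single-group branch added above replaces the collective with a local copy: when only one rank participates, the gathered result is just the input itself. A minimal sketch of that degenerate case using plain torch tensors (names and shapes here are illustrative, not taken from lightllm):

import torch

def all_gather_into_tensor_single(output_: torch.Tensor, input_: torch.Tensor) -> None:
    # With a world size of 1 the gathered tensor is exactly the local shard,
    # so the collective can be skipped and replaced by a plain copy.
    output_.copy_(input_)

item = torch.tensor([5])                     # local value on the only rank
gathered = torch.empty(1, dtype=item.dtype)  # world_size * shard_size == 1
all_gather_into_tensor_single(gathered, item)
assert gathered.item() == 5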
@@ -213,6 +218,10 @@ def all_gather(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
+    if _is_single_group(group=group):
+        if len(output_) > 0:
+            output_[0].copy_(input_)
+        return
     # TODO: no custom-op support here yet.
     if isinstance(group, CustomProcessGroup):
         return dist.all_gather(output_, input_, group.device_group, async_op)
@@ -227,11 +236,35 @@ def reduce_scatter_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op=False,
 ):
+    if _is_single_group(group=group):
+        output.copy_(input)
+        return
     # No custom-op implementation here yet.
     if isinstance(group, CustomProcessGroup):
         return dist.reduce_scatter_tensor(output, input, op=op, group=group.device_group, async_op=async_op)
     else:
         return dist.reduce_scatter_tensor(output, input, op=op, group=group, async_op=async_op)
 
 
+def broadcast(
+    tensor: torch.Tensor,
+    src: int,
+    group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
+    async_op: bool = False,
+) -> None:
+    if _is_single_group(group=group):
+        return
+    if isinstance(group, CustomProcessGroup):
+        return dist.broadcast(tensor, src=src, group=group.device_group, async_op=async_op)
+    else:
+        return dist.broadcast(tensor, src=src, group=group, async_op=async_op)
+
+
+def _is_single_group(group: Optional[Union[ProcessGroup, CustomProcessGroup]]) -> bool:
+    if isinstance(group, CustomProcessGroup):
+        return group.dp_world_size == 1
+    else:
+        return dist.get_world_size(group=group) == 1
+
+
 dist_group_manager = DistributeGroupManager()
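
The new _is_single_group helper drives all of the early returns above: for a CustomProcessGroup it checks dp_world_size, otherwise it asks torch.distributed for the group's world size. A rough sketch of the same guard written against a plain torch ProcessGroup only (the dist.is_initialized() check is an extra assumption added for a self-contained example, not part of this patch):

from typing import Optional
import torch
import torch.distributed as dist

def broadcast(tensor: torch.Tensor, src: int,
              group: Optional[dist.ProcessGroup] = None,
              async_op: bool = False):
    # Single-rank group: the source rank already holds the data, so skip the
    # collective entirely instead of issuing a real broadcast.
    if not dist.is_initialized() or dist.get_world_size(group=group) == 1:
        return
    return dist.broadcast(tensor, src=src, group=group, async_op=async_op)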

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 12 additions & 7 deletions
@@ -31,7 +31,12 @@
     enable_radix_tree_timer_merge,
     get_radix_tree_merge_update_delta,
 )
-from lightllm.distributed import dist_group_manager
+from lightllm.distributed.communication_op import (
+    dist_group_manager,
+    all_gather_into_tensor,
+    all_reduce,
+    broadcast,
+)
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
@@ -368,7 +373,7 @@ def _try_read_new_reqs_normal(self):
             self.node_broadcast_tensor.fill_(0)
 
         src_rank_id = self.args.node_rank * self.node_world_size
-        dist.broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
+        broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
             self._read_reqs_buffer_and_init_reqs()
@@ -382,7 +387,7 @@ def _try_read_new_reqs_normal(self):
             self.node_broadcast_tensor.fill_(0)
 
         src_rank_id = self.args.node_rank * self.node_world_size
-        dist.broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
+        broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
             self._read_nixl_trans_io_buffer_and_update_req_status()
@@ -396,7 +401,7 @@ def _try_read_new_reqs_multinode_tp(self):
             self.multinode_tp_gather_item_tensor.fill_(1)
         else:
             self.multinode_tp_gather_item_tensor.fill_(0)
-        dist.all_gather_into_tensor(
+        all_gather_into_tensor(
             self.multinode_tp_all_gather_tensor,
             self.multinode_tp_gather_item_tensor,
             group=self.multinode_tp_nccl_group,
@@ -806,12 +811,12 @@ def _dp_all_gather_prefill_and_decode_req_num(
         """
         current_dp_prefill_num = len(prefill_reqs)
         self.dp_gather_item_tensor.fill_(current_dp_prefill_num)
-        dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
+        all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
         dp_prefill_req_nums = self.dp_all_gather_tensor.cpu().numpy()
 
         current_dp_decode_num = len(decode_reqs)
         self.dp_gather_item_tensor.fill_(current_dp_decode_num)
-        dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
+        all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
         dp_decode_req_nums = self.dp_all_gather_tensor.cpu().numpy()
 
         return dp_prefill_req_nums, dp_decode_req_nums
@@ -822,7 +827,7 @@ def _dp_all_reduce_decode_req_num(self, decode_reqs: List[InferReq]) -> int:
         """
         current_dp_decode_num = len(decode_reqs)
         self.dp_reduce_tensor.fill_(current_dp_decode_num)
-        dist.all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
+        all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
         max_decode_num = self.dp_reduce_tensor.item()
         return max_decode_num
 
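With the call sites routed through these wrappers, single-DP runs never issue the collectives at all. A hypothetical smoke test (not part of this commit) that checks the degenerate case against the real collective on a one-rank gloo group, so it runs on CPU without NCCL; the address and port are placeholders:

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
t = torch.tensor([7])
dist.all_reduce(t, op=dist.ReduceOp.MAX)  # reducing over a single rank leaves t unchanged
assert t.item() == 7                      # same value the fast path yields without any collective
dist.destroy_process_group()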