33import logging
44import os
55import pickle
6+ from datetime import timedelta
67from enum import Enum
78from typing import Any , List , Optional
89
@@ -79,7 +80,7 @@ def init_process_group(
7980 port : int ,
8081 rank : int ,
8182 world_size : int ,
82- timeout : int = 300 ,
83+ timeout : timedelta = timedelta ( seconds = 300 ) ,
8384 ** kwargs ,
8485 ):
8586 self ._host = host
@@ -88,7 +89,9 @@ def init_process_group(
8889 self ._world_size = world_size
8990 self ._device = torch .device ("cuda" , rank )
9091
91- self .pg = StatelessProcessGroup .create (host , port , rank , world_size , store_timeout = timeout )
92+ self .pg = StatelessProcessGroup .create (
93+ host , port , rank , world_size , store_timeout = int (timeout .total_seconds ())
94+ )
9295
9396 from vllm .distributed .device_communicators .pynccl import PyNcclCommunicator
9497
@@ -186,6 +189,27 @@ def new_group(self, ranks):
186189 )
187190 from vllm_ascend .utils import current_stream
188191
192+ class HcclCommConfig (ctypes .Structure ):
193+ _fields_ = [
194+ ("size" , ctypes .c_size_t ),
195+ ("magic_word" , ctypes .c_uint32 ),
196+ ("version" , ctypes .c_uint32 ),
197+ ("reserved" , ctypes .c_uint64 ),
198+ ("hccl_buffer_size" , ctypes .c_uint32 ),
199+ ("hccl_deterministic" , ctypes .c_uint32 ),
200+ ("hccl_comm_name" , ctypes .c_char * 128 ),
201+ ("hccl_udi" , ctypes .c_char * 128 ),
202+ ("hccl_op_expansion_mode" , ctypes .c_uint32 ),
203+ ("hccl_rdma_traffic_class" , ctypes .c_uint32 ),
204+ ("hccl_rdma_service_level" , ctypes .c_uint32 ),
205+ ("hcll_world_rank_id" , ctypes .c_uint32 ),
206+ ("hccl_job_id" , ctypes .c_uint64 ),
207+ ("comm_engine" , ctypes .c_int32 ),
208+ ("thread_num" , ctypes .c_uint32 ),
209+ ("notify_num_per_thread" , ctypes .c_uint32 ),
210+ ("acl_graph_zero_copy_enable" , ctypes .c_uint8 ),
211+ ]
212+
189213 orig_exported_functions = HCCLLibrary .exported_functions
190214 extended_functions = [
191215 # HcclResult HcclAllGather(
@@ -217,7 +241,7 @@ def new_group(self, ranks):
217241 ctypes .POINTER (ctypes .c_uint32 ),
218242 ctypes .c_uint64 ,
219243 ctypes .c_uint32 ,
220- ctypes .POINTER (hcclUniqueId ),
244+ ctypes .POINTER (HcclCommConfig ),
221245 ctypes .POINTER (hcclComm_t ),
222246 ],
223247 ),
@@ -228,27 +252,6 @@ def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream):
228252 self ._funcs ["HcclAllGather" ](send_buf , recv_buf , count , data_type , comm , stream )
229253 )
230254
231- class HcclCommConfig (ctypes .Structure ):
232- _fields_ = [
233- ("size" , ctypes .c_size_t ),
234- ("magic_word" , ctypes .c_uint32 ),
235- ("version" , ctypes .c_uint32 ),
236- ("reserved" , ctypes .c_uint64 ),
237- ("hccl_buffer_size" , ctypes .c_uint32 ),
238- ("hccl_deterministic" , ctypes .c_uint32 ),
239- ("hccl_comm_name" , ctypes .c_char * 128 ),
240- ("hccl_udi" , ctypes .c_char * 128 ),
241- ("hccl_op_expansion_mode" , ctypes .c_uint32 ),
242- ("hccl_rdma_traffic_class" , ctypes .c_uint32 ),
243- ("hccl_rdma_service_level" , ctypes .c_uint32 ),
244- ("hcll_world_rank_id" , ctypes .c_uint32 ),
245- ("hccl_job_id" , ctypes .c_uint64 ),
246- ("comm_engine" , ctypes .c_int32 ),
247- ("thread_num" , ctypes .c_uint32 ),
248- ("notify_num_per_thread" , ctypes .c_uint32 ),
249- ("acl_graph_zero_copy_enable" , ctypes .c_uint8 ),
250- ]
251-
252255 def hccl_create_subcomm_config (
253256 self , comm , ranks_size , c_rank_ids , subcomm_id , subcomm_rank , comm_config
254257 ):
@@ -274,55 +277,13 @@ def hccl_create_subcomm_config(
274277 class PyHcclCommunicatorEx (PyHcclCommunicator ):
275278 def __init__ (self , group , device ):
276279 super ().__init__ (group , device )
277- self .subcomms = {}
278280 self .subcomm_id = 1
279281
280- def destroy_comm (self ):
281- self .hccl .hcclCommDestroy (self .comm )
282-
283- def all_reduce (
284- self ,
285- in_tensor : torch .Tensor ,
286- op : ReduceOp = ReduceOp .SUM ,
287- stream = None ,
288- ) -> torch .Tensor :
289- if self .disabled :
290- return None
291- assert in_tensor .device == self .device , (
292- f"this hccl communicator is created to work on { self .device } , "
293- f"but the input tensor is on { in_tensor .device } "
294- )
295- out_tensor = torch .empty_like (in_tensor )
296- if stream is None :
297- stream = current_stream ()
298- self .hccl .hcclAllReduce (
299- buffer_type (in_tensor .data_ptr ()),
300- buffer_type (out_tensor .data_ptr ()),
301- in_tensor .numel (),
302- hcclDataTypeEnum .from_torch (in_tensor .dtype ),
303- hcclRedOpTypeEnum .from_torch (op ),
304- self .comm , # todo
305- aclrtStream_t (stream .npu_stream ),
306- )
307- return out_tensor
308-
309- def broadcast (self , tensor : torch .Tensor , src : int , stream = None ):
310- if self .disabled :
311- return None
312- assert tensor .device == self .device , (
313- f"this hccl communicator is created to work on { self .device } , "
314- f"but the input tensor is on { tensor .device } "
315- )
316- if stream is None :
317- stream = current_stream ()
318- self .hccl .hcclBroadcast (
319- buffer_type (tensor .data_ptr ()),
320- tensor .numel (),
321- hcclDataTypeEnum .from_torch (tensor .dtype ),
322- src ,
323- self .comm , # todo
324- aclrtStream_t (stream .npu_stream ),
325- )
282+ def destroy_comm (self , comm = None ):
283+ if comm :
284+ self .hccl .hcclCommDestroy (comm )
285+ else :
286+ self .hccl .hcclCommDestroy (self .comm )
326287
327288 def all_gather (self , out_tensor : torch .Tensor , in_tensor : torch .Tensor , stream = None ):
328289 if self .disabled :
@@ -343,10 +304,7 @@ def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=N
343304 )
344305 return out_tensor
345306
346- def create_subcomm (
347- self ,
348- ranks ,
349- ):
307+ def create_subcomm (self , ranks ):
350308 comm_config = HcclCommConfig (
351309 size = 312 ,
352310 magic_word = 0xF0F0F0F0 ,
@@ -375,7 +333,6 @@ def create_subcomm(
375333 subcomm = self .hccl .hcclCreateSubCommConfig (
376334 self .comm , ranks_size , c_rank_ids , subcomm_id , subcomm_rank , comm_config
377335 )
378- self .subcomms [subcomm_id ] = subcomm
379336 self .subcomm_id += 1
380337 return subcomm
381338
@@ -391,7 +348,7 @@ def init_process_group(
391348 port : int ,
392349 rank : int ,
393350 world_size : int ,
394- timeout : int = 300 ,
351+ timeout : timedelta = timedelta ( seconds = 300 ) ,
395352 ** kwargs ,
396353 ):
397354 self ._host = host
@@ -401,13 +358,15 @@ def init_process_group(
401358 self ._device = torch .device ("npu" , rank )
402359
403360 self .pg = StatelessProcessGroup .create (
404- host , port , rank , world_size , store_timeout = timeout
361+ host , port , rank , world_size , store_timeout = int ( timeout . total_seconds ())
405362 )
406363 self .pyhccl = PyHcclCommunicatorEx (group = self .pg , device = self ._device )
364+ self ._comm = self .pyhccl .comm
407365
408366 def destroy_process_group (self , group = None ):
409367 if group in self .sub_groups :
410- group .pyhccl .destroy_comm ()
368+ subcomm = ctypes .c_void_p (group )
369+ self .pyhccl .destroy_comm (subcomm )
411370 del self .sub_groups [group ]
412371 return
413372
@@ -422,69 +381,70 @@ def is_initialized(self) -> bool:
422381 def all_gather_object (self , object_list : list [Any ], obj : Any , group = None ):
423382 if group :
424383 assert group in self .sub_groups , "invalid sub_group"
425- pyhccl = group . pyhccl
426- else :
427- pyhccl = self . pyhccl
428- _common_all_gather_object (pyhccl , self ._device , self ._world_size , object_list , obj )
384+ subcomm = ctypes . c_void_p ( group )
385+ self . pyhccl . comm = subcomm
386+
387+ _common_all_gather_object (self . pyhccl , self ._device , self ._world_size , object_list , obj )
429388 current_stream ().synchronize ()
430389
390+ if group :
391+ self .pyhccl .comm = self ._comm
392+
431393 def all_reduce (self , tensor : torch .Tensor , op = ReduceOp .SUM , group = None ):
432394 if group :
433395 assert group in self .sub_groups , "invalid sub_group"
434- pyhccl = group .pyhccl
435- else :
436- pyhccl = self .pyhccl
396+ subcomm = ctypes .c_void_p (group )
397+ self .pyhccl .comm = subcomm
437398
438- out_tensor = pyhccl .all_reduce (tensor , op )
399+ out_tensor = self . pyhccl .all_reduce (tensor , op )
439400 current_stream ().synchronize ()
440401 tensor .copy_ (out_tensor )
441402
403+ if group :
404+ self .pyhccl .comm = self ._comm
405+
442406 def broadcast (self , tensor : torch .Tensor , src = None , group = None ):
443407 if group :
444408 assert group in self .sub_groups , "invalid sub_group"
445409 assert src in self .sub_groups [group ], "src rank not in group"
446- pyhccl = group .pyhccl
447- # src is rank id in global world
410+ subcomm = ctypes .c_void_p (group )
411+ self .pyhccl .comm = subcomm
412+ # convert src rank id in default world to subcomm
448413 src = self .sub_groups [group ].index (src )
449- else :
450- pyhccl = self .pyhccl
451414
452- pyhccl .broadcast (tensor , src )
415+ self . pyhccl .broadcast (tensor , src )
453416 current_stream ().synchronize ()
454417
418+ if group :
419+ self .pyhccl .comm = self ._comm
420+
455421 def barrier (self , group = None ):
456422 if group :
457423 assert group in self .sub_groups , "invalid sub_group"
458- pyhccl = group .pyhccl
459- else :
460- pyhccl = self .pyhccl
424+ subcomm = ctypes .c_void_p (group )
425+ self .pyhccl .comm = subcomm
461426
462427 data = torch .zeros (1 , device = self ._rank )
463- pyhccl .all_reduce (data )
428+ self . pyhccl .all_reduce (data )
464429 current_stream ().synchronize ()
465430
431+ if group :
432+ self .pyhccl .comm = self ._comm
433+
466434 def new_group (self , ranks ):
467- # ranks is None or []
435+ # if ranks is None or [], using the world instead
468436 if not ranks :
469- return self
470-
471- host = self ._host
472- port = self ._port
473- rank = self ._rank
437+ ranks = list (range (self ._world_size ))
474438
475- if rank not in ranks :
439+ if self . _rank not in ranks :
476440 return
477441
478- new_rank = ranks .index (rank )
479- new_world_size = len (ranks )
480-
481- new_dist = DistributedHccl ()
482- new_dist .init_process_group (
483- host , port + 10 , new_rank , new_world_size
484- ) # todo host maybe incorrect
485- self .sub_groups [new_dist ] = ranks
486-
487- return new_dist
442+ subcomm = self .pyhccl .create_subcomm (ranks )
443+ value = 0
444+ if subcomm :
445+ value = subcomm .value
446+ self .sub_groups [value ] = ranks
447+ return value
488448
489449except ImportError as e :
490450 pass
0 commit comments