2525
2626import torch
2727from torch import nn
28- from torch .distributed .tensor .device_mesh import DeviceMesh
2928from tqdm import tqdm
3029from tqdm .contrib .logging import logging_redirect_tqdm
3130
@@ -149,7 +148,7 @@ def __init__(
149148 model_device : torch .device ,
150149 model_dtype : torch .dtype ,
151150 scheduler : Scheduler ,
152- device_mesh : DeviceMesh | None ,
151+ distributed_helper : DistributedHelper ,
153152 ) -> None :
154153 """Initialize the continuous batch processor.
155154
@@ -166,7 +165,7 @@ def __init__(
166165 model_device: Device for model inputs/outputs
167166 model_dtype: Data type for model inputs/outputs
168167 scheduler: The [`Scheduler`] to use
169- device_mesh : The device mesh if there is one
168+ distributed_helper : The [`DistributedHelper`] to use
170169 """
171170 self .cache = cache
172171 self .config = config
@@ -179,7 +178,7 @@ def __init__(
179178 self .model_device = model_device
180179 self .model_dtype = model_dtype
181180 self .scheduler = scheduler
182- self .distributed_helper = DistributedHelper ( device_mesh = device_mesh )
181+ self .distributed_helper = distributed_helper
183182
184183 # Generation-related attributes
185184 self .do_sample = getattr (generation_config , "do_sample" , True )
@@ -268,7 +267,7 @@ def _get_new_requests(self) -> None:
268267 payload = (new_states , cancellations )
269268 # Otherwise, the payload is a pair of empty lists (no new states, no cancellations)
270269 else :
271- payload = None
270+ payload = ([], [])
272271
273272 # Broadcast within the TP group. No-op when tp_size == 1, returns the driver's payload unchanged.
274273 payload = self .distributed_helper .tp_broadcast_object (payload )
@@ -521,11 +520,11 @@ def __init__(
521520 self ._request_lock = threading .Lock ()
522521
523522 # Infer if this process is the driver of its own TP group
524- helper = DistributedHelper (device_mesh = getattr (self .model , "_device_mesh" , None ))
525- self .is_tp_driver = helper .is_tp_driver
523+ self . distributed_helper = DistributedHelper (device_mesh = getattr (self .model , "_device_mesh" , None ))
524+ self .is_tp_driver = self . distributed_helper .is_tp_driver
526525 # If TP is on, check if NCCL graph mixing is disabled (helps with performance)
527526 if continuous_batching_config .disable_nccl_graph_mixing :
528- helper .maybe_warn_nccl_graph_mixing ()
527+ self . distributed_helper .maybe_warn_nccl_graph_mixing ()
529528
530529 # Generation config related arguments
531530 num_return_sequences = getattr (generation_config , "num_return_sequences" , None )
@@ -601,6 +600,7 @@ def stop(self, block: bool = True, timeout: float | None = None, keep_for_next_s
601600 # If the manager is not being kept for next session, we clear the batch processor
602601 if not keep_for_next_session :
603602 self .batch_processor = None
603+ self .distributed_helper .destroy_ingress_group ()
604604 # Otherwise, we keep the batch processor and cache the manager as a model attribute
605605 else :
606606 logger .info ("Continuous batching manager will be kept for next session." )
@@ -792,15 +792,13 @@ def _generation_step(self) -> None:
792792 self .batch_processor ._generation_step (self .model )
793793
794794 def _create_batch_processor (self ) -> ContinuousBatchProcessor :
795- # Retrieve the device mesh if there is one
796- device_mesh : DeviceMesh | None = getattr (self .model , "_device_mesh" , None )
797795 # Create the PagedAttentionCache
798796 paged_attention_cache = PagedAttentionCache (
799- self .model .config ,
800- self .continuous_batching_config ,
801- self .model .device ,
802- self .model . dtype ,
803- tp_size = DistributedHelper ( device_mesh = device_mesh ). tp_size , # consistent with the batch processor
797+ config = self .model .config ,
798+ continuous_batching_config = self .continuous_batching_config ,
799+ device = self .model .device ,
800+ distributed_helper = self .distributed_helper ,
801+ dtype = self . model . dtype ,
804802 )
805803 self ._use_prefix_sharing = paged_attention_cache .use_prefix_sharing # update the approximation
806804
@@ -829,7 +827,7 @@ def _create_batch_processor(self) -> ContinuousBatchProcessor:
829827 model_device = self .model .device ,
830828 model_dtype = self .model .dtype ,
831829 scheduler = scheduler (paged_attention_cache ),
832- device_mesh = device_mesh ,
830+ distributed_helper = self . distributed_helper ,
833831 )
834832 return batch_processor
835833
0 commit comments