@@ -356,7 +356,7 @@ def test_continuous_batching_will_allocation_be_successful(
356356 num_free_blocks : int ,
357357 expected_result : bool ,
358358 ) -> None :
359- """Test the will_allocation_be_successful method of PagedAttentionCache, overloading the elevant attributes of
359+ """Test the will_allocation_be_successful method of PagedAttentionCache, overloading the relevant attributes of
360360 a dummy cache."""
361361
362362 if torch_device is None : # this check which should always pass and helps with type checking
@@ -532,15 +532,21 @@ def test_distributed_helper_set_tp_seed_no_dist(self) -> None:
532532 helper .set_tp_seed (seed = None , model_device = torch .device ("cpu" ))
533533
534534 def test_continuous_batching_config_disables_nccl_graph_mixing (self ) -> None :
535- """Test that constructing a ContinuousBatchingConfig sets NCCL_GRAPH_MIXING_SUPPORT=0 by default and only sets
536- it when the disable_nccl_graph_mixing flag is on."""
537- original = os .environ .pop ("NCCL_GRAPH_MIXING_SUPPORT" , None )
535+ """Test that ContinuousBatchingConfig sets NCCL_GRAPH_MIXING_SUPPORT=0 only under a distributed launch
536+ (WORLD_SIZE > 1) and respects the disable_nccl_graph_mixing flag."""
537+ original_nccl = os .environ .pop ("NCCL_GRAPH_MIXING_SUPPORT" , None )
538+ original_ws = os .environ .pop ("WORLD_SIZE" , None )
538539 try :
539- # Default: env var is set to "0"
540+ # Single-GPU launch (no WORLD_SIZE): env var is left untouched
541+ ContinuousBatchingConfig ()
542+ self .assertNotIn ("NCCL_GRAPH_MIXING_SUPPORT" , os .environ )
543+
544+ # Distributed launch (WORLD_SIZE > 1): env var is set to "0"
545+ os .environ ["WORLD_SIZE" ] = "2"
540546 ContinuousBatchingConfig ()
541547 self .assertEqual (os .environ .get ("NCCL_GRAPH_MIXING_SUPPORT" ), "0" )
542548
543- # Explicitly disabled flag: env var is left untouched
549+ # Explicitly disabled flag: env var is left untouched even under a distributed launch
544550 os .environ .pop ("NCCL_GRAPH_MIXING_SUPPORT" , None )
545551 ContinuousBatchingConfig (disable_nccl_graph_mixing = False )
546552 self .assertNotIn ("NCCL_GRAPH_MIXING_SUPPORT" , os .environ )
@@ -550,10 +556,14 @@ def test_continuous_batching_config_disables_nccl_graph_mixing(self) -> None:
550556 ContinuousBatchingConfig ()
551557 self .assertEqual (os .environ .get ("NCCL_GRAPH_MIXING_SUPPORT" ), "1" )
552558 finally :
553- if original is None :
559+ if original_nccl is None :
554560 os .environ .pop ("NCCL_GRAPH_MIXING_SUPPORT" , None )
555561 else :
556- os .environ ["NCCL_GRAPH_MIXING_SUPPORT" ] = original
562+ os .environ ["NCCL_GRAPH_MIXING_SUPPORT" ] = original_nccl
563+ if original_ws is None :
564+ os .environ .pop ("WORLD_SIZE" , None )
565+ else :
566+ os .environ ["WORLD_SIZE" ] = original_ws
557567
558568
559569@require_torch_accelerator
0 commit comments