@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
6060 model .eval ()
6161
6262 # Move inputs to device
63- inputs_on_device = {}
64- for key , value in inputs_dict .items ():
65- if isinstance (value , torch .Tensor ):
66- inputs_on_device [key ] = value .to (device )
67- else :
68- inputs_on_device [key ] = value
63+ inputs_on_device = {k : v .to (device ) if isinstance (v , torch .Tensor ) else v for k , v in inputs_dict .items ()}
6964
7065 # Enable context parallelism
7166 cp_config = ContextParallelConfig (** cp_dict )
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
8984 dist .destroy_process_group ()
9085
9186
def _custom_mesh_worker(
    rank,
    world_size,
    master_port,
    model_class,
    init_dict,
    cp_dict,
    mesh_shape,
    mesh_dim_names,
    inputs_dict,
    return_dict,
):
    """Per-rank worker: run context-parallel inference on a caller-supplied DeviceMesh.

    Rank 0 reports the outcome through the shared ``return_dict``: ``status`` is
    ``"success"`` (with ``output_shape``) or ``"error"`` (with ``error``).
    """
    try:
        # Rendezvous configuration for torch.distributed.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        model = model_class(**init_dict)
        model.to(device)
        model.eval()

        # Tensors move to this rank's GPU; non-tensor entries pass through untouched.
        device_inputs = {}
        for name, value in inputs_dict.items():
            device_inputs[name] = value.to(device) if isinstance(value, torch.Tensor) else value

        # The DeviceMesh must be built after init_process_group, inside each
        # worker process, so it binds to this process group.
        mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
        model.enable_parallelism(config=cp_config)

        with torch.no_grad():
            output = model(**device_inputs, return_dict=False)[0]

        if rank == 0:
            return_dict["status"] = "success"
            return_dict["output_shape"] = list(output.shape)

    except Exception as e:
        # Surface the failure to the parent process instead of dying silently.
        if rank == 0:
            return_dict["status"] = "error"
            return_dict["error"] = str(e)
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
139+
92140@is_context_parallel
93141@require_torch_multi_accelerator
94142class ContextParallelTesterMixin :
@@ -126,3 +174,48 @@ def test_context_parallel_inference(self, cp_type):
126174 assert return_dict .get ("status" ) == "success" , (
127175 f"Context parallel inference failed: { return_dict .get ('error' , 'Unknown error' )} "
128176 )
177+
178+ @pytest .mark .parametrize (
179+ "cp_type,mesh_shape,mesh_dim_names" ,
180+ [
181+ ("ring_degree" , (2 , 1 , 1 ), ("ring" , "ulysses" , "fsdp" )),
182+ ("ulysses_degree" , (1 , 2 , 1 ), ("ring" , "ulysses" , "fsdp" )),
183+ ],
184+ ids = ["ring-3d-fsdp" , "ulysses-3d-fsdp" ],
185+ )
186+ def test_context_parallel_custom_mesh (self , cp_type , mesh_shape , mesh_dim_names ):
187+ if not torch .distributed .is_available ():
188+ pytest .skip ("torch.distributed is not available." )
189+
190+ if not hasattr (self .model_class , "_cp_plan" ) or self .model_class ._cp_plan is None :
191+ pytest .skip ("Model does not have a _cp_plan defined for context parallel inference." )
192+
193+ world_size = 2
194+ init_dict = self .get_init_dict ()
195+ inputs_dict = {k : v .cpu () if isinstance (v , torch .Tensor ) else v for k , v in self .get_dummy_inputs ().items ()}
196+ cp_dict = {cp_type : world_size }
197+
198+ master_port = _find_free_port ()
199+ manager = mp .Manager ()
200+ return_dict = manager .dict ()
201+
202+ mp .spawn (
203+ _custom_mesh_worker ,
204+ args = (
205+ world_size ,
206+ master_port ,
207+ self .model_class ,
208+ init_dict ,
209+ cp_dict ,
210+ mesh_shape ,
211+ mesh_dim_names ,
212+ inputs_dict ,
213+ return_dict ,
214+ ),
215+ nprocs = world_size ,
216+ join = True ,
217+ )
218+
219+ assert return_dict .get ("status" ) == "success" , (
220+ f"Custom mesh context parallel inference failed: { return_dict .get ('error' , 'Unknown error' )} "
221+ )
0 commit comments