2626from ...testing_utils import (
2727 is_context_parallel ,
2828 require_torch_multi_accelerator ,
29+ torch_device ,
2930)
3031
3132
33+ # Device configuration mapping
34+ DEVICE_CONFIG = {
35+ "cuda" : {"backend" : "nccl" , "module" : torch .cuda },
36+ "xpu" : {"backend" : "xccl" , "module" : torch .xpu },
37+ }
38+
39+
3240def _find_free_port ():
3341 """Find a free port on localhost."""
3442 with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
@@ -47,12 +55,17 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
4755 os .environ ["RANK" ] = str (rank )
4856 os .environ ["WORLD_SIZE" ] = str (world_size )
4957
58+ # Get device configuration
59+ device_config = DEVICE_CONFIG .get (torch_device , DEVICE_CONFIG ["cuda" ])
60+ backend = device_config ["backend" ]
61+ device_module = device_config ["module" ]
62+
5063 # Initialize process group
51- dist .init_process_group (backend = "nccl" , rank = rank , world_size = world_size )
64+ dist .init_process_group (backend = backend , rank = rank , world_size = world_size )
5265
5366 # Set device for this process
54- torch . cuda .set_device (rank )
55- device = torch .device (f"cuda :{ rank } " )
67+ device_module .set_device (rank )
68+ device = torch .device (f"{ torch_device } :{ rank } " )
5669
5770 # Create model
5871 model = model_class (** init_dict )
@@ -103,10 +116,16 @@ def _custom_mesh_worker(
103116 os .environ ["RANK" ] = str (rank )
104117 os .environ ["WORLD_SIZE" ] = str (world_size )
105118
106- dist .init_process_group (backend = "nccl" , rank = rank , world_size = world_size )
119+ # Get device configuration
120+ device_config = DEVICE_CONFIG .get (torch_device , DEVICE_CONFIG ["cuda" ])
121+ backend = device_config ["backend" ]
122+ device_module = device_config ["module" ]
107123
108- torch .cuda .set_device (rank )
109- device = torch .device (f"cuda:{ rank } " )
124+ dist .init_process_group (backend = backend , rank = rank , world_size = world_size )
125+
126+ # Set device for this process
127+ device_module .set_device (rank )
128+ device = torch .device (f"{ torch_device } :{ rank } " )
110129
111130 model = model_class (** init_dict )
112131 model .to (device )
@@ -116,7 +135,7 @@ def _custom_mesh_worker(
116135
117136 # DeviceMesh must be created after init_process_group, inside each worker process.
118137 mesh = torch .distributed .device_mesh .init_device_mesh (
119- "cuda" , mesh_shape = mesh_shape , mesh_dim_names = mesh_dim_names
138+ torch_device , mesh_shape = mesh_shape , mesh_dim_names = mesh_dim_names
120139 )
121140 cp_config = ContextParallelConfig (** cp_dict , mesh = mesh )
122141 model .enable_parallelism (config = cp_config )
0 commit comments