Skip to content

Commit 9677859

Browse files
sywangyi and sayakpaul authored
fix parallelism case failure in xpu (#13270)
* fix parallelism case failure in xpu Signed-off-by: Wang, Yi <yi.a.wang@intel.com> * updated Signed-off-by: Wang, Yi <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi <yi.a.wang@intel.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent ed31974 commit 9677859

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

tests/models/testing_utils/parallelism.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,17 @@
2626
from ...testing_utils import (
2727
is_context_parallel,
2828
require_torch_multi_accelerator,
29+
torch_device,
2930
)
3031

3132

33+
# Device configuration mapping
34+
DEVICE_CONFIG = {
35+
"cuda": {"backend": "nccl", "module": torch.cuda},
36+
"xpu": {"backend": "xccl", "module": torch.xpu},
37+
}
38+
39+
3240
def _find_free_port():
3341
"""Find a free port on localhost."""
3442
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -47,12 +55,17 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
4755
os.environ["RANK"] = str(rank)
4856
os.environ["WORLD_SIZE"] = str(world_size)
4957

58+
# Get device configuration
59+
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
60+
backend = device_config["backend"]
61+
device_module = device_config["module"]
62+
5063
# Initialize process group
51-
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
64+
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
5265

5366
# Set device for this process
54-
torch.cuda.set_device(rank)
55-
device = torch.device(f"cuda:{rank}")
67+
device_module.set_device(rank)
68+
device = torch.device(f"{torch_device}:{rank}")
5669

5770
# Create model
5871
model = model_class(**init_dict)
@@ -103,10 +116,16 @@ def _custom_mesh_worker(
103116
os.environ["RANK"] = str(rank)
104117
os.environ["WORLD_SIZE"] = str(world_size)
105118

106-
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
119+
# Get device configuration
120+
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
121+
backend = device_config["backend"]
122+
device_module = device_config["module"]
107123

108-
torch.cuda.set_device(rank)
109-
device = torch.device(f"cuda:{rank}")
124+
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
125+
126+
# Set device for this process
127+
device_module.set_device(rank)
128+
device = torch.device(f"{torch_device}:{rank}")
110129

111130
model = model_class(**init_dict)
112131
model.to(device)
@@ -116,7 +135,7 @@ def _custom_mesh_worker(
116135

117136
# DeviceMesh must be created after init_process_group, inside each worker process.
118137
mesh = torch.distributed.device_mesh.init_device_mesh(
119-
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
138+
torch_device, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
120139
)
121140
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
122141
model.enable_parallelism(config=cp_config)

0 commit comments

Comments (0)