Skip to content

Commit 07c5ba8

Browse files
DN6 and sayakpaul authored
[Context Parallel] Add support for custom device mesh (#13064)
* add custom mesh support * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 897aed7 commit 07c5ba8

File tree

3 files changed

+112
-8
lines changed

3 files changed

+112
-8
lines changed

src/diffusers/models/_modeling_parallel.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ class ContextParallelConfig:
6060
rotate_method (`str`, *optional*, defaults to `"allgather"`):
6161
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
6262
is supported.
63+
ulysses_anything (`bool`, *optional*, defaults to `False`):
64+
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
65+
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
66+
`ring_degree` must be 1.
67+
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
68+
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
69+
creating a new one. This is useful when combining context parallelism with other parallelism strategies
70+
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
71+
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
72+
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
6373
6474
"""
6575

@@ -68,6 +78,7 @@ class ContextParallelConfig:
6878
convert_to_fp32: bool = True
6979
# TODO: support alltoall
7080
rotate_method: Literal["allgather", "alltoall"] = "allgather"
81+
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
7182
# Whether to enable ulysses anything attention to support
7283
# any sequence lengths and any head numbers.
7384
ulysses_anything: bool = False
@@ -124,7 +135,7 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
124135
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
125136
)
126137

127-
self._flattened_mesh = self._mesh._flatten()
138+
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
128139
self._ring_mesh = self._mesh["ring"]
129140
self._ulysses_mesh = self._mesh["ulysses"]
130141
self._ring_local_rank = self._ring_mesh.get_local_rank()

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1567,7 +1567,7 @@ def enable_parallelism(
15671567
mesh = None
15681568
if config.context_parallel_config is not None:
15691569
cp_config = config.context_parallel_config
1570-
mesh = torch.distributed.device_mesh.init_device_mesh(
1570+
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
15711571
device_type=device_type,
15721572
mesh_shape=cp_config.mesh_shape,
15731573
mesh_dim_names=cp_config.mesh_dim_names,

tests/models/testing_utils/parallelism.py

Lines changed: 99 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
6060
model.eval()
6161

6262
# Move inputs to device
63-
inputs_on_device = {}
64-
for key, value in inputs_dict.items():
65-
if isinstance(value, torch.Tensor):
66-
inputs_on_device[key] = value.to(device)
67-
else:
68-
inputs_on_device[key] = value
63+
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
6964

7065
# Enable context parallelism
7166
cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
8984
dist.destroy_process_group()
9085

9186

87+
def _custom_mesh_worker(
88+
rank,
89+
world_size,
90+
master_port,
91+
model_class,
92+
init_dict,
93+
cp_dict,
94+
mesh_shape,
95+
mesh_dim_names,
96+
inputs_dict,
97+
return_dict,
98+
):
99+
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
100+
try:
101+
os.environ["MASTER_ADDR"] = "localhost"
102+
os.environ["MASTER_PORT"] = str(master_port)
103+
os.environ["RANK"] = str(rank)
104+
os.environ["WORLD_SIZE"] = str(world_size)
105+
106+
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
107+
108+
torch.cuda.set_device(rank)
109+
device = torch.device(f"cuda:{rank}")
110+
111+
model = model_class(**init_dict)
112+
model.to(device)
113+
model.eval()
114+
115+
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
116+
117+
# DeviceMesh must be created after init_process_group, inside each worker process.
118+
mesh = torch.distributed.device_mesh.init_device_mesh(
119+
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
120+
)
121+
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
122+
model.enable_parallelism(config=cp_config)
123+
124+
with torch.no_grad():
125+
output = model(**inputs_on_device, return_dict=False)[0]
126+
127+
if rank == 0:
128+
return_dict["status"] = "success"
129+
return_dict["output_shape"] = list(output.shape)
130+
131+
except Exception as e:
132+
if rank == 0:
133+
return_dict["status"] = "error"
134+
return_dict["error"] = str(e)
135+
finally:
136+
if dist.is_initialized():
137+
dist.destroy_process_group()
138+
139+
92140
@is_context_parallel
93141
@require_torch_multi_accelerator
94142
class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ def test_context_parallel_inference(self, cp_type):
126174
assert return_dict.get("status") == "success", (
127175
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
128176
)
177+
178+
@pytest.mark.parametrize(
179+
"cp_type,mesh_shape,mesh_dim_names",
180+
[
181+
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
182+
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
183+
],
184+
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
185+
)
186+
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
187+
if not torch.distributed.is_available():
188+
pytest.skip("torch.distributed is not available.")
189+
190+
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
191+
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
192+
193+
world_size = 2
194+
init_dict = self.get_init_dict()
195+
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
196+
cp_dict = {cp_type: world_size}
197+
198+
master_port = _find_free_port()
199+
manager = mp.Manager()
200+
return_dict = manager.dict()
201+
202+
mp.spawn(
203+
_custom_mesh_worker,
204+
args=(
205+
world_size,
206+
master_port,
207+
self.model_class,
208+
init_dict,
209+
cp_dict,
210+
mesh_shape,
211+
mesh_dim_names,
212+
inputs_dict,
213+
return_dict,
214+
),
215+
nprocs=world_size,
216+
join=True,
217+
)
218+
219+
assert return_dict.get("status") == "success", (
220+
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
221+
)

0 commit comments

Comments (0)