Skip to content

Commit 929ab72

Browse files
committed
fix: style
1 parent 2480388 commit 929ab72

6 files changed

Lines changed: 78 additions & 9 deletions

File tree

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ def log_validation(pipeline, args, accelerator, generator, global_step, is_final
8585
os.makedirs(val_save_dir)
8686

8787
original_image = (
88-
lambda image_url_or_path: load_image(image_url_or_path)
89-
if urlparse(image_url_or_path).scheme
90-
else Image.open(image_url_or_path).convert("RGB")
88+
lambda image_url_or_path: (
89+
load_image(image_url_or_path)
90+
if urlparse(image_url_or_path).scheme
91+
else Image.open(image_url_or_path).convert("RGB")
92+
)
9193
)(args.val_image_url_or_path)
9294

9395
if torch.backends.mps.is_available():

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
logger = logging.get_logger(__name__)
4747

4848
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
49-
lambda: (lambda model_cls, weights: weights),
49+
lambda: lambda model_cls, weights: weights,
5050
{
5151
"UNet2DConditionModel": _maybe_expand_lora_scales,
5252
"UNetMotionModel": _maybe_expand_lora_scales,

src/diffusers/models/_modeling_parallel.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
# - Unified Attention
3636
# - More dispatcher attention backends
3737
# - CFG/Data Parallel
38-
# - Tensor Parallel
3938

4039

4140
@dataclass
@@ -142,6 +141,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
142141
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
143142

144143

144+
@dataclass
145+
class TensorParallelConfig:
146+
"""
147+
Configuration for tensor parallelism.
148+
149+
Tensor parallelism shards weight matrices (column-wise and row-wise) across devices.
150+
Each device computes a partial result; an AllReduce/AllGather at layer boundaries
151+
reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module``
152+
with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles.
153+
154+
On Neuron, use the ``_pre_shard_and_tp`` workaround from
155+
``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug
156+
on large tensors (>= 5120x5120).
157+
158+
Args:
159+
tp_degree (`int`, defaults to `1`):
160+
Number of devices to shard across. Must be a divisor of the number of
161+
attention heads (and FFN hidden dimensions) of the model being parallelised.
162+
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
163+
A custom device mesh to use. If provided, ``tp_degree`` is inferred from
164+
``mesh.size()`` and the argument is ignored. Useful when combining TP with
165+
other parallelism strategies (e.g. CP) that share the same mesh.
166+
"""
167+
168+
tp_degree: int = 1
169+
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
170+
171+
_rank: int = None
172+
_world_size: int = None
173+
_device: torch.device = None
174+
_mesh: torch.distributed.device_mesh.DeviceMesh = None
175+
176+
def __post_init__(self):
177+
if self.tp_degree < 1:
178+
raise ValueError("`tp_degree` must be >= 1.")
179+
180+
def setup(
181+
self,
182+
rank: int,
183+
world_size: int,
184+
device: torch.device,
185+
mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
186+
):
187+
self._rank = rank
188+
self._world_size = world_size
189+
self._device = device
190+
if mesh is not None:
191+
self._mesh = mesh
192+
elif self.mesh is not None:
193+
self._mesh = self.mesh
194+
else:
195+
from torch.distributed.device_mesh import init_device_mesh
196+
197+
device_type = str(device).split(":")[0]
198+
self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",))
199+
200+
145201
@dataclass
146202
class ParallelConfig:
147203
"""
@@ -150,9 +206,12 @@ class ParallelConfig:
150206
Args:
151207
context_parallel_config (`ContextParallelConfig`, *optional*):
152208
Configuration for context parallelism.
209+
tensor_parallel_config (`TensorParallelConfig`, *optional*):
210+
Configuration for tensor parallelism.
153211
"""
154212

155213
context_parallel_config: ContextParallelConfig | None = None
214+
tensor_parallel_config: TensorParallelConfig | None = None
156215

157216
_rank: int = None
158217
_world_size: int = None
@@ -173,6 +232,8 @@ def setup(
173232
self._mesh = mesh
174233
if self.context_parallel_config is not None:
175234
self.context_parallel_config.setup(rank, world_size, device, mesh)
235+
if self.tensor_parallel_config is not None:
236+
self.tensor_parallel_config.setup(rank, world_size, device, mesh)
176237

177238

178239
@dataclass(frozen=True)

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import types
2323
from dataclasses import dataclass
2424
from pathlib import Path
25-
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints
25+
from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints
2626

2727
import httpx
2828
import numpy as np
@@ -68,7 +68,6 @@
6868
is_transformers_version,
6969
logging,
7070
numpy_to_pil,
71-
requires_backends,
7271
)
7372
from ..utils.distributed_utils import is_torch_dist_rank_zero
7473
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
@@ -2249,6 +2248,7 @@ def _is_pipeline_device_mapped(self):
22492248

22502249
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
22512250

2251+
22522252
class StableDiffusionMixin:
22532253
r"""
22542254
Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion)

src/diffusers/utils/torch_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,14 @@
3939
import torch
4040
from torch.fft import fftn, fftshift, ifftn, ifftshift
4141

42-
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True}
42+
BACKEND_SUPPORTS_TRAINING = {
43+
"cuda": True,
44+
"xpu": True,
45+
"cpu": True,
46+
"mps": False,
47+
"neuron": False,
48+
"default": True,
49+
}
4350
BACKEND_EMPTY_CACHE = {
4451
"cuda": torch.cuda.empty_cache,
4552
"xpu": torch.xpu.empty_cache,

tests/pipelines/pixart_alpha/test_pixart.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
PixArtAlphaPipeline,
2828
PixArtTransformer2DModel,
2929
)
30-
3130
from diffusers.utils.import_utils import is_torch_neuronx_available
3231

3332
from ...testing_utils import (

0 commit comments

Comments
 (0)