Skip to content

Commit 9254417

Browse files
authored
Fix Helios Context Parallelism (#13223)
* fix Helios Context Parallelism * refacotr * make style and quality
1 parent e1b5db5 commit 9254417

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

src/diffusers/models/transformers/transformer_helios.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -556,14 +556,21 @@ class HeliosTransformer3DModel(
556556
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
557557
_repeated_blocks = ["HeliosTransformerBlock"]
558558
_cp_plan = {
559-
"blocks.0": {
559+
# Input split at attn level and ffn level.
560+
"blocks.*.attn1": {
560561
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
561-
},
562-
"blocks.*": {
563-
"temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False),
564562
"rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
565563
},
566-
"blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3),
564+
"blocks.*.attn2": {
565+
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
566+
},
567+
"blocks.*.ffn": {
568+
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
569+
},
570+
# Output gather at attn level and ffn level.
571+
**{f"blocks.{i}.attn1": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)},
572+
**{f"blocks.{i}.attn2": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)},
573+
**{f"blocks.{i}.ffn": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)},
567574
}
568575

569576
@register_to_config

src/diffusers/pipelines/helios/pipeline_helios_pyramid.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,14 @@ def sample_block_noise(
449449
width,
450450
patch_size: tuple[int, ...] = (1, 2, 2),
451451
device: torch.device | None = None,
452+
generator: torch.Generator | None = None,
452453
):
454+
# NOTE: A generator must be provided to ensure correct and reproducible results.
455+
# Creating a default generator here is a fallback only — without a fixed seed,
456+
# the output will be non-deterministic and may produce incorrect results in CP context.
457+
if generator is None:
458+
generator = torch.Generator(device=device)
459+
453460
gamma = self.scheduler.config.gamma
454461
_, ph, pw = patch_size
455462
block_size = ph * pw
@@ -458,13 +465,17 @@ def sample_block_noise(
458465
torch.eye(block_size, device=device) * (1 + gamma)
459466
- torch.ones(block_size, block_size, device=device) * gamma
460467
)
461-
cov += torch.eye(block_size, device=device) * 1e-6
462-
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov)
468+
cov += torch.eye(block_size, device=device) * 1e-8
469+
cov = cov.float() # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16.
470+
471+
L = torch.linalg.cholesky(cov)
463472
block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
473+
z = torch.randn(block_number, block_size, device=device, generator=generator)
474+
noise = z @ L.T
464475

465-
noise = dist.sample((block_number,)) # [block number, block_size]
466476
noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw)
467477
noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
478+
468479
return noise
469480

470481
@property
@@ -918,7 +929,14 @@ def __call__(
918929

919930
batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape
920931
noise = self.sample_block_noise(
921-
batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device
932+
batch_size,
933+
channel,
934+
num_frames,
935+
pyramid_height,
936+
pyramid_width,
937+
patch_size,
938+
device,
939+
generator,
922940
)
923941
noise = noise.to(device=device, dtype=transformer_dtype)
924942
latents = alpha * latents + beta * noise # To fix the block artifact

0 commit comments

Comments
 (0)