Skip to content

Commit 94bcb89

Browse files
authored
fix: allow pass cpu generator for helios (#13228)
* allow pass cpu generator for helios * allow pass cpu generator for helios * allow pass cpu generator for helios * patch
1 parent 8ea908f commit 94bcb89

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/pipelines/helios/pipeline_helios_pyramid.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,8 @@ def sample_block_noise(
456456
# the output will be non-deterministic and may produce incorrect results in CP context.
457457
if generator is None:
458458
generator = torch.Generator(device=device)
459+
elif isinstance(generator, list):
460+
generator = generator[0]
459461

460462
gamma = self.scheduler.config.gamma
461463
_, ph, pw = patch_size
@@ -470,7 +472,8 @@ def sample_block_noise(
470472

471473
L = torch.linalg.cholesky(cov)
472474
block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
473-
z = torch.randn(block_number, block_size, device=device, generator=generator)
475+
z = torch.randn(block_number, block_size, generator=generator, device=generator.device)
476+
z = z.to(device=device)
474477
noise = z @ L.T
475478

476479
noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw)

0 commit comments

Comments
 (0)