@@ -449,7 +449,14 @@ def sample_block_noise(
449449 width ,
450450 patch_size : tuple [int , ...] = (1 , 2 , 2 ),
451451 device : torch .device | None = None ,
452+ generator : torch .Generator | None = None ,
452453 ):
454+ # NOTE: A generator must be provided to ensure correct and reproducible results.
455+ # Creating a default generator here is a fallback only — without a fixed seed,
456+ # the output will be non-deterministic and may produce incorrect results in CP context.
457+ if generator is None :
458+ generator = torch .Generator (device = device )
459+
453460 gamma = self .scheduler .config .gamma
454461 _ , ph , pw = patch_size
455462 block_size = ph * pw
@@ -458,13 +465,17 @@ def sample_block_noise(
458465 torch .eye (block_size , device = device ) * (1 + gamma )
459466 - torch .ones (block_size , block_size , device = device ) * gamma
460467 )
461- cov += torch .eye (block_size , device = device ) * 1e-6
462- dist = torch .distributions .MultivariateNormal (torch .zeros (block_size , device = device ), covariance_matrix = cov )
468+ cov += torch .eye (block_size , device = device ) * 1e-8
469+ cov = cov .float () # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16.
470+
471+ L = torch .linalg .cholesky (cov )
463472 block_number = batch_size * channel * num_frames * (height // ph ) * (width // pw )
473+ z = torch .randn (block_number , block_size , device = device , generator = generator )
474+ noise = z @ L .T
464475
465- noise = dist .sample ((block_number ,)) # [block number, block_size]
466476 noise = noise .view (batch_size , channel , num_frames , height // ph , width // pw , ph , pw )
467477 noise = noise .permute (0 , 1 , 2 , 3 , 5 , 4 , 6 ).reshape (batch_size , channel , num_frames , height , width )
478+
468479 return noise
469480
470481 @property
@@ -918,7 +929,14 @@ def __call__(
918929
919930 batch_size , channel , num_frames , pyramid_height , pyramid_width = latents .shape
920931 noise = self .sample_block_noise (
921- batch_size , channel , num_frames , pyramid_height , pyramid_width , patch_size , device
932+ batch_size ,
933+ channel ,
934+ num_frames ,
935+ pyramid_height ,
936+ pyramid_width ,
937+ patch_size ,
938+ device ,
939+ generator ,
922940 )
923941 noise = noise .to (device = device , dtype = transformer_dtype )
924942 latents = alpha * latents + beta * noise # To fix the block artifact
0 commit comments