@@ -99,13 +99,16 @@ def sde_step_with_logprob(
9999 # This is also reproducible, because I have set global seed in the trainer.
100100 # Some local seeding would not impact the global seed, and thus the reproducibility.
101101 )
102- prev_sample = prev_sample_mean + std_dev_t * torch .sqrt (- 1 * dt ) * variance_noise
102+ prev_sample = (
103+ prev_sample_mean + std_dev_t * torch .sqrt (- 1 * dt ) * variance_noise
104+ )
103105
104106 if deterministic :
105107 prev_sample = sample + dt * model_output
106108
107109 log_prob = (
108- - ((prev_sample .detach () - prev_sample_mean ) ** 2 ) / (2 * ((std_dev_t * torch .sqrt (- 1 * dt )) ** 2 ))
110+ - ((prev_sample .detach () - prev_sample_mean ) ** 2 )
111+ / (2 * ((std_dev_t * torch .sqrt (- 1 * dt )) ** 2 ))
109112 - torch .log (std_dev_t * torch .sqrt (- 1 * dt ))
110113 - torch .log (torch .sqrt (2 * torch .as_tensor (math .pi )))
111114 )
@@ -114,9 +117,9 @@ def sde_step_with_logprob(
114117 std_dev_t = sigma_prev * math .sin (noise_level * math .pi / 2 ) # sigma_t in paper
115118 pred_original_sample = sample - sigma * model_output # predicted x_0 in paper
116119 noise_estimate = sample + model_output * (1 - sigma ) # predicted x_1 in paper
117- prev_sample_mean = pred_original_sample * (1 - sigma_prev ) + noise_estimate * torch . sqrt (
118- sigma_prev ** 2 - std_dev_t ** 2
119- )
120+ prev_sample_mean = pred_original_sample * (
121+ 1 - sigma_prev
122+ ) + noise_estimate * torch . sqrt ( sigma_prev ** 2 - std_dev_t ** 2 )
120123
121124 if prev_sample is None :
122125 variance_noise = randn_tensor (
@@ -128,14 +131,18 @@ def sde_step_with_logprob(
128131 prev_sample = prev_sample_mean + std_dev_t * variance_noise
129132
130133 if deterministic :
131- prev_sample = pred_original_sample * (1 - sigma_prev ) + noise_estimate * sigma_prev
134+ prev_sample = (
135+ pred_original_sample * (1 - sigma_prev ) + noise_estimate * sigma_prev
136+ )
132137
133138 # remove all constants
134139 log_prob = - ((prev_sample .detach () - prev_sample_mean ) ** 2 )
135140
136141 else :
137142 msg = f"Unknown sde_type: { sde_type } . Must be 'flow_sde' or 'flow_cps'."
138- raise ValueError (msg )
143+ raise ValueError (
144+ msg
145+ )
139146
140147 # mean along all but batch dimension
141148 log_prob = log_prob .mean (dim = tuple (range (1 , log_prob .ndim )))
@@ -205,7 +212,12 @@ def wan_pipeline_with_logprob(
205212 )
206213
207214 if num_frames % self .vae_scale_factor_temporal != 1 :
208- num_frames = num_frames // self .vae_scale_factor_temporal * self .vae_scale_factor_temporal + 1
215+ num_frames = (
216+ num_frames
217+ // self .vae_scale_factor_temporal
218+ * self .vae_scale_factor_temporal
219+ + 1
220+ )
209221 num_frames = max (num_frames , 1 )
210222
211223 self ._guidance_scale = guidance_scale
@@ -265,18 +277,24 @@ def wan_pipeline_with_logprob(
265277 f"sde_window_range span ({ sde_window_range [1 ] - sde_window_range [0 ]} ) "
266278 f"must be >= sde_window_size ({ sde_window_size } )"
267279 )
268- raise ValueError (msg )
280+ raise ValueError (
281+ msg
282+ )
269283 # Use generator if provided (for training reproducibility), otherwise fallback to random
270284 if generator is not None :
271285 # Extract generator from list if needed
272286 gen = generator [0 ] if isinstance (generator , list ) and len (generator ) > 0 else generator
273287 # Use torch.randint with generator for deterministic randomness
274288 max_start = sde_window_range [1 ] - sde_window_size
275- start = torch .randint (sde_window_range [0 ], max_start + 1 , (1 ,), generator = gen , device = device ).item ()
289+ start = torch .randint (
290+ sde_window_range [0 ], max_start + 1 , (1 ,), generator = gen , device = device
291+ ).item ()
276292 else :
277293 # Fallback to Python random (for eval, where generator may not be provided)
278294 # This is safe because eval uses deterministic=True and set_seed at the start
279- start = random .randint (sde_window_range [0 ], sde_window_range [1 ] - sde_window_size )
295+ start = random .randint (
296+ sde_window_range [0 ], sde_window_range [1 ] - sde_window_size
297+ )
280298 end = start + sde_window_size
281299 sde_window = (start , end )
282300 # In window mode, initialize all_latents as empty list (will be populated in the loop)
@@ -383,7 +401,9 @@ def wan_pipeline_with_logprob(
383401
384402 latents = callback_outputs .pop ("latents" , latents )
385403 prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
386- negative_prompt_embeds = callback_outputs .pop ("negative_prompt_embeds" , negative_prompt_embeds )
404+ negative_prompt_embeds = callback_outputs .pop (
405+ "negative_prompt_embeds" , negative_prompt_embeds
406+ )
387407
388408 # Compute KL reward
389409 if use_window :
@@ -392,7 +412,9 @@ def wan_pipeline_with_logprob(
392412 if in_window :
393413 if kl_reward > 0 and not deterministic :
394414 latent_model_input = (
395- torch .cat ([latents_ori ] * 2 ) if self .do_classifier_free_guidance else latents_ori
415+ torch .cat ([latents_ori ] * 2 )
416+ if self .do_classifier_free_guidance
417+ else latents_ori
396418 )
397419 ref_model = getattr (self , "ref_transformer" , None )
398420 if ref_model is not None :
@@ -418,7 +440,9 @@ def wan_pipeline_with_logprob(
418440 # perform guidance
419441 if self .do_classifier_free_guidance :
420442 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
421- noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
443+ noise_pred = noise_pred_uncond + self .guidance_scale * (
444+ noise_pred_text - noise_pred_uncond
445+ )
422446
423447 (
424448 _ ,
@@ -440,15 +464,21 @@ def wan_pipeline_with_logprob(
440464 diffusion_clip_value = diffusion_clip_value ,
441465 )
442466 assert std_dev_t == ref_std_dev_t
443- kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (2 * std_dev_t ** 2 )
467+ kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (
468+ 2 * std_dev_t ** 2
469+ )
444470 kl = kl .mean (dim = tuple (range (1 , kl .ndim )))
445471 all_kl .append (kl )
446472 else :
447473 # In window but no KL reward, append zero KL
448474 all_kl .append (torch .zeros (len (latents ), device = latents .device ))
449475 # Original mode: compute KL for all timesteps (sde_window_size == 0)
450476 elif kl_reward > 0 and not deterministic :
451- latent_model_input = torch .cat ([latents_ori ] * 2 ) if self .do_classifier_free_guidance else latents_ori
477+ latent_model_input = (
478+ torch .cat ([latents_ori ] * 2 )
479+ if self .do_classifier_free_guidance
480+ else latents_ori
481+ )
452482 ref_model = getattr (self , "ref_transformer" , None )
453483 if ref_model is not None :
454484 ref_ctx = contextlib .nullcontext ()
@@ -473,7 +503,9 @@ def wan_pipeline_with_logprob(
473503 # perform guidance
474504 if self .do_classifier_free_guidance :
475505 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
476- noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
506+ noise_pred = noise_pred_uncond + self .guidance_scale * (
507+ noise_pred_text - noise_pred_uncond
508+ )
477509
478510 (
479511 _ ,
@@ -495,15 +527,19 @@ def wan_pipeline_with_logprob(
495527 diffusion_clip_value = diffusion_clip_value ,
496528 )
497529 assert std_dev_t == ref_std_dev_t
498- kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (2 * std_dev_t ** 2 )
530+ kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (
531+ 2 * std_dev_t ** 2
532+ )
499533 kl = kl .mean (dim = tuple (range (1 , kl .ndim )))
500534 all_kl .append (kl )
501535 else :
502536 # no kl reward, we do not need to compute, just put a pre-position value, kl will be 0
503537 all_kl .append (torch .zeros (len (latents ), device = latents .device ))
504538
505539 # call the callback, if provided
506- if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
540+ if i == len (timesteps ) - 1 or (
541+ (i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0
542+ ):
507543 progress_bar .update ()
508544
509545 self ._current_timestep = None
@@ -515,9 +551,9 @@ def wan_pipeline_with_logprob(
515551 .view (1 , self .vae .config .z_dim , 1 , 1 , 1 )
516552 .to (latents .device , latents .dtype )
517553 )
518- latents_std = 1.0 / torch .tensor (self .vae .config .latents_std ).view (1 , self . vae . config . z_dim , 1 , 1 , 1 ). to (
519- latents . device , latents . dtype
520- )
554+ latents_std = 1.0 / torch .tensor (self .vae .config .latents_std ).view (
555+ 1 , self . vae . config . z_dim , 1 , 1 , 1
556+ ). to ( latents . device , latents . dtype )
521557 latents = latents / latents_std + latents_mean
522558 # Decode one sample at a time to reduce peak memory.
523559 decoded_videos = []
0 commit comments