Skip to content

Commit 636a71e

Browse files
committed
style: apply ruff format to remaining files
- Format long lines in wan_pipeline_with_logprob.py - Format function calls in hpsv3.py
1 parent 6ff3184 commit 636a71e

2 files changed

Lines changed: 64 additions & 24 deletions

File tree

genrl/diffusers_patch/wan_pipeline_with_logprob.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,16 @@ def sde_step_with_logprob(
9999
# This is also reproducible, because I have set global seed in the trainer.
100100
# Some local seeding would not impact the global seed, and thus the reproducibility.
101101
)
102-
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise
102+
prev_sample = (
103+
prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise
104+
)
103105

104106
if deterministic:
105107
prev_sample = sample + dt * model_output
106108

107109
log_prob = (
108-
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
110+
-((prev_sample.detach() - prev_sample_mean) ** 2)
111+
/ (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
109112
- torch.log(std_dev_t * torch.sqrt(-1 * dt))
110113
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
111114
)
@@ -114,9 +117,9 @@ def sde_step_with_logprob(
114117
std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2) # sigma_t in paper
115118
pred_original_sample = sample - sigma * model_output # predicted x_0 in paper
116119
noise_estimate = sample + model_output * (1 - sigma) # predicted x_1 in paper
117-
prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(
118-
sigma_prev**2 - std_dev_t**2
119-
)
120+
prev_sample_mean = pred_original_sample * (
121+
1 - sigma_prev
122+
) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2)
120123

121124
if prev_sample is None:
122125
variance_noise = randn_tensor(
@@ -128,14 +131,18 @@ def sde_step_with_logprob(
128131
prev_sample = prev_sample_mean + std_dev_t * variance_noise
129132

130133
if deterministic:
131-
prev_sample = pred_original_sample * (1 - sigma_prev) + noise_estimate * sigma_prev
134+
prev_sample = (
135+
pred_original_sample * (1 - sigma_prev) + noise_estimate * sigma_prev
136+
)
132137

133138
# remove all constants
134139
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
135140

136141
else:
137142
msg = f"Unknown sde_type: {sde_type}. Must be 'flow_sde' or 'flow_cps'."
138-
raise ValueError(msg)
143+
raise ValueError(
144+
msg
145+
)
139146

140147
# mean along all but batch dimension
141148
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
@@ -205,7 +212,12 @@ def wan_pipeline_with_logprob(
205212
)
206213

207214
if num_frames % self.vae_scale_factor_temporal != 1:
208-
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
215+
num_frames = (
216+
num_frames
217+
// self.vae_scale_factor_temporal
218+
* self.vae_scale_factor_temporal
219+
+ 1
220+
)
209221
num_frames = max(num_frames, 1)
210222

211223
self._guidance_scale = guidance_scale
@@ -265,18 +277,24 @@ def wan_pipeline_with_logprob(
265277
f"sde_window_range span ({sde_window_range[1] - sde_window_range[0]}) "
266278
f"must be >= sde_window_size ({sde_window_size})"
267279
)
268-
raise ValueError(msg)
280+
raise ValueError(
281+
msg
282+
)
269283
# Use generator if provided (for training reproducibility), otherwise fallback to random
270284
if generator is not None:
271285
# Extract generator from list if needed
272286
gen = generator[0] if isinstance(generator, list) and len(generator) > 0 else generator
273287
# Use torch.randint with generator for deterministic randomness
274288
max_start = sde_window_range[1] - sde_window_size
275-
start = torch.randint(sde_window_range[0], max_start + 1, (1,), generator=gen, device=device).item()
289+
start = torch.randint(
290+
sde_window_range[0], max_start + 1, (1,), generator=gen, device=device
291+
).item()
276292
else:
277293
# Fallback to Python random (for eval, where generator may not be provided)
278294
# This is safe because eval uses deterministic=True and set_seed at the start
279-
start = random.randint(sde_window_range[0], sde_window_range[1] - sde_window_size)
295+
start = random.randint(
296+
sde_window_range[0], sde_window_range[1] - sde_window_size
297+
)
280298
end = start + sde_window_size
281299
sde_window = (start, end)
282300
# In window mode, initialize all_latents as empty list (will be populated in the loop)
@@ -383,7 +401,9 @@ def wan_pipeline_with_logprob(
383401

384402
latents = callback_outputs.pop("latents", latents)
385403
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
386-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
404+
negative_prompt_embeds = callback_outputs.pop(
405+
"negative_prompt_embeds", negative_prompt_embeds
406+
)
387407

388408
# Compute KL reward
389409
if use_window:
@@ -392,7 +412,9 @@ def wan_pipeline_with_logprob(
392412
if in_window:
393413
if kl_reward > 0 and not deterministic:
394414
latent_model_input = (
395-
torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori
415+
torch.cat([latents_ori] * 2)
416+
if self.do_classifier_free_guidance
417+
else latents_ori
396418
)
397419
ref_model = getattr(self, "ref_transformer", None)
398420
if ref_model is not None:
@@ -418,7 +440,9 @@ def wan_pipeline_with_logprob(
418440
# perform guidance
419441
if self.do_classifier_free_guidance:
420442
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
421-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
443+
noise_pred = noise_pred_uncond + self.guidance_scale * (
444+
noise_pred_text - noise_pred_uncond
445+
)
422446

423447
(
424448
_,
@@ -440,15 +464,21 @@ def wan_pipeline_with_logprob(
440464
diffusion_clip_value=diffusion_clip_value,
441465
)
442466
assert std_dev_t == ref_std_dev_t
443-
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (2 * std_dev_t**2)
467+
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (
468+
2 * std_dev_t**2
469+
)
444470
kl = kl.mean(dim=tuple(range(1, kl.ndim)))
445471
all_kl.append(kl)
446472
else:
447473
# In window but no KL reward, append zero KL
448474
all_kl.append(torch.zeros(len(latents), device=latents.device))
449475
# Original mode: compute KL for all timesteps (sde_window_size == 0)
450476
elif kl_reward > 0 and not deterministic:
451-
latent_model_input = torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori
477+
latent_model_input = (
478+
torch.cat([latents_ori] * 2)
479+
if self.do_classifier_free_guidance
480+
else latents_ori
481+
)
452482
ref_model = getattr(self, "ref_transformer", None)
453483
if ref_model is not None:
454484
ref_ctx = contextlib.nullcontext()
@@ -473,7 +503,9 @@ def wan_pipeline_with_logprob(
473503
# perform guidance
474504
if self.do_classifier_free_guidance:
475505
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
476-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
506+
noise_pred = noise_pred_uncond + self.guidance_scale * (
507+
noise_pred_text - noise_pred_uncond
508+
)
477509

478510
(
479511
_,
@@ -495,15 +527,19 @@ def wan_pipeline_with_logprob(
495527
diffusion_clip_value=diffusion_clip_value,
496528
)
497529
assert std_dev_t == ref_std_dev_t
498-
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (2 * std_dev_t**2)
530+
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (
531+
2 * std_dev_t**2
532+
)
499533
kl = kl.mean(dim=tuple(range(1, kl.ndim)))
500534
all_kl.append(kl)
501535
else:
502536
# no kl reward, we do not need to compute, just put a pre-position value, kl will be 0
503537
all_kl.append(torch.zeros(len(latents), device=latents.device))
504538

505539
# call the callback, if provided
506-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
540+
if i == len(timesteps) - 1 or (
541+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
542+
):
507543
progress_bar.update()
508544

509545
self._current_timestep = None
@@ -515,9 +551,9 @@ def wan_pipeline_with_logprob(
515551
.view(1, self.vae.config.z_dim, 1, 1, 1)
516552
.to(latents.device, latents.dtype)
517553
)
518-
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
519-
latents.device, latents.dtype
520-
)
554+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
555+
1, self.vae.config.z_dim, 1, 1, 1
556+
).to(latents.device, latents.dtype)
521557
latents = latents / latents_std + latents_mean
522558
# Decode one sample at a time to reduce peak memory.
523559
decoded_videos = []

genrl/reward/hpsv3.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def _fn(
145145
# Repeat the general prompt for all frames
146146
frame_prompts = [general_prompt] * len(frame_paths)
147147
with torch.no_grad(), torch.amp.autocast(device_type=device_type):
148-
frame_rewards_raw = inferencer.reward(frame_prompts, image_paths=frame_paths)
148+
frame_rewards_raw = inferencer.reward(
149+
frame_prompts, image_paths=frame_paths
150+
)
149151

150152
# Extract mu values (mean scores)
151153
# HPSv3 returns a list where each element is [mu, sigma] or a tensor
@@ -224,7 +226,9 @@ def _fn(
224226
# Use the same prompt for all frames in the video
225227
frame_prompts = [prompt] * len(frame_paths)
226228
with torch.no_grad(), torch.amp.autocast(device_type=device_type):
227-
frame_rewards_raw = inferencer.reward(frame_prompts, image_paths=frame_paths)
229+
frame_rewards_raw = inferencer.reward(
230+
frame_prompts, image_paths=frame_paths
231+
)
228232

229233
# Extract mu values (mean scores)
230234
# HPSv3 returns a list where each element is [mu, sigma] or a tensor

0 commit comments

Comments
 (0)