Skip to content

Commit 235e0f8

Browse files
committed
fix(agent): dpmpp.py last step high-order correction missing.
1 parent 26b94be commit 235e0f8

2 files changed

Lines changed: 15 additions & 14 deletions

File tree

iddm/model/samples/dpmpp.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,11 @@ def _sample_loop(
117117

118118
# DPM++ 2M (second-order)
119119
if self.order == 2:
120-
if i < len(time_steps) - 1:
121-
# Intermediate noise prediction
122-
x_inter = torch.sqrt(alpha_prev) * x0 + torch.sqrt(
123-
1 - alpha_prev - sigma ** 2) * predicted_noise
124-
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels,
125-
cfg_scale)
126-
127-
# Second-order correction
128-
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2
120+
# 2nd-order correction: only needs model at t_prev, available on all steps
121+
sqrt_term = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
122+
x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term * predicted_noise
123+
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, cfg_scale)
124+
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2
129125

130126
# DPM++ 3M (third-order)
131127
elif self.order == 3:
@@ -134,20 +130,25 @@ def _sample_loop(
134130
t_next_tensor = (torch.ones(n) * t_next).long().to(self.device)
135131
alpha_next = self.alpha_hat[t_next_tensor][:, None, None, None]
136132

137-
# First intermediate step, 1e-8 to avoid NaN
133+
# First intermediate step
138134
sqrt_term1 = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
139135
x_inter1 = torch.sqrt(alpha_prev) * x0 + sqrt_term1 * predicted_noise
140136
pred_noise1 = self._get_predicted_noise(model, x_inter1, t_prev_tensor, labels, cfg_scale)
141137

142-
# Second intermediate step, 1e-8 to avoid NaN
138+
# Second intermediate step
143139
sqrt_term2 = torch.sqrt(torch.clamp(1 - alpha_next - sigma ** 2, min=1e-8))
144140
x_inter2 = torch.sqrt(alpha_next) * x0 + sqrt_term2 * pred_noise1
145141
pred_noise2 = self._get_predicted_noise(model, x_inter2, t_next_tensor, labels, cfg_scale)
146142

147143
# Third-order correction
148144
predicted_noise = (23 * predicted_noise - 16 * pred_noise1 + 5 * pred_noise2) / 12
149-
# Or use a more stable variant
150-
# predicted_noise = (18 * predicted_noise - 12 * pred_noise1 + 3 * pred_noise2) / 9
145+
else:
146+
# Last step: no look-ahead available, fallback to 2nd-order correction
147+
sqrt_term_inter = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
148+
x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term_inter * predicted_noise
149+
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels,
150+
cfg_scale)
151+
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2
151152

152153
# Add noise for stochastic sampling
153154
noise = torch.randn_like(x) if t > 1 else torch.zeros_like(x)

iddm/utils/initializer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def sample_initializer(sample, image_size, device, schedule_name="linear", **kwa
236236
**kwargs)
237237
elif sample == "dpmpp3m":
238238
diffusion = DPMPlusPlusDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, order=3,
239-
**kwargs)
239+
sample_steps=50, **kwargs)
240240
else:
241241
diffusion = DDPMDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, **kwargs)
242242
logger.warning(msg=f"[{device}]: Setting sample error, we has been automatically set to ddpm.")

0 commit comments

Comments
 (0)