Skip to content

Commit e437fb9

Browse files
authored
Merge pull request #189 from chairc/dev-agent
sample fix by agent.
2 parents 67bdc5c + 7352f06 commit e437fb9

4 files changed

Lines changed: 18 additions & 19 deletions

File tree

iddm/model/samples/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,5 @@
2424
from .ddim import DDIMDiffusion
2525
from .ddpm import DDPMDiffusion
2626
from .plms import PLMSDiffusion
27+
from .dpm2 import DPM2Diffusion
28+
from .dpmpp import DPMPlusPlusDiffusion

iddm/model/samples/dpmpp.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,11 @@ def _sample_loop(
117117

118118
# DPM++ 2M (second-order)
119119
if self.order == 2:
120-
if i < len(time_steps) - 1:
121-
# Intermediate noise prediction
122-
x_inter = torch.sqrt(alpha_prev) * x0 + torch.sqrt(
123-
1 - alpha_prev - sigma ** 2) * predicted_noise
124-
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels,
125-
cfg_scale)
126-
127-
# Second-order correction
128-
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2
120+
# 2nd-order correction: only needs model at t_prev, available on all steps
121+
sqrt_term = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
122+
x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term * predicted_noise
123+
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, cfg_scale)
124+
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2
129125

130126
# DPM++ 3M (third-order)
131127
elif self.order == 3:
@@ -134,20 +130,25 @@ def _sample_loop(
134130
t_next_tensor = (torch.ones(n) * t_next).long().to(self.device)
135131
alpha_next = self.alpha_hat[t_next_tensor][:, None, None, None]
136132

137-
# First intermediate step, 1e-8 to avoid NaN
133+
# First intermediate step
138134
sqrt_term1 = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
139135
x_inter1 = torch.sqrt(alpha_prev) * x0 + sqrt_term1 * predicted_noise
140136
pred_noise1 = self._get_predicted_noise(model, x_inter1, t_prev_tensor, labels, cfg_scale)
141137

142-
# Second intermediate step, 1e-8 to avoid NaN
138+
# Second intermediate step
143139
sqrt_term2 = torch.sqrt(torch.clamp(1 - alpha_next - sigma ** 2, min=1e-8))
144140
x_inter2 = torch.sqrt(alpha_next) * x0 + sqrt_term2 * pred_noise1
145141
pred_noise2 = self._get_predicted_noise(model, x_inter2, t_next_tensor, labels, cfg_scale)
146142

147143
# Third-order correction
148144
predicted_noise = (23 * predicted_noise - 16 * pred_noise1 + 5 * pred_noise2) / 12
149-
# Or use a more stable variant
150-
# predicted_noise = (18 * predicted_noise - 12 * pred_noise1 + 3 * pred_noise2) / 9
145+
else:
146+
# Last step: no look-ahead available, fallback to 2nd-order correction
147+
sqrt_term_inter = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
148+
x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term_inter * predicted_noise
149+
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels,
150+
cfg_scale)
151+
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2
151152

152153
# Add noise for stochastic sampling
153154
noise = torch.randn_like(x) if t > 1 else torch.zeros_like(x)

iddm/model/samples/plms.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,7 @@ def _sample_loop(
111111
c1 = self.eta * torch.sqrt((1 - alpha_t / alpha_prev) * (1 - alpha_prev) / (1 - alpha_t))
112112
c2 = torch.sqrt((1 - alpha_prev) - c1 ** 2)
113113
p_x = torch.sqrt(alpha_prev) * x0_t + c2 * predicted_noise + c1 * noise
114-
if labels is None and cfg_scale is None:
115-
# Images and time steps input into the model
116-
predicted_noise_next = model(p_x, p_t)
117-
else:
118-
predicted_noise_next = model(p_x, p_t, labels)
114+
predicted_noise_next = self._get_predicted_noise(model, p_x, p_t, labels, cfg_scale)
119115
predicted_noise_prime = (predicted_noise + predicted_noise_next) / 2
120116
elif len(old_eps) == 1:
121117
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)

iddm/utils/initializer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def sample_initializer(sample, image_size, device, schedule_name="linear", **kwa
236236
**kwargs)
237237
elif sample == "dpmpp3m":
238238
diffusion = DPMPlusPlusDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, order=3,
239-
**kwargs)
239+
sample_steps=50, **kwargs)
240240
else:
241241
diffusion = DDPMDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, **kwargs)
242242
logger.warning(msg=f"[{device}]: Setting sample error, we has been automatically set to ddpm.")

0 commit comments

Comments
 (0)