@@ -117,15 +117,11 @@ def _sample_loop(
117117
118118 # DPM++ 2M (second-order)
119119 if self .order == 2 :
120- if i < len (time_steps ) - 1 :
121- # Intermediate noise prediction
122- x_inter = torch .sqrt (alpha_prev ) * x0 + torch .sqrt (
123- 1 - alpha_prev - sigma ** 2 ) * predicted_noise
124- predicted_noise_inter = self ._get_predicted_noise (model , x_inter , t_prev_tensor , labels ,
125- cfg_scale )
126-
127- # Second-order correction
128- predicted_noise = (3 * predicted_noise - predicted_noise_inter ) / 2
120+ # 2nd-order correction: only needs model at t_prev, available on all steps
121+ sqrt_term = torch .sqrt (torch .clamp (1 - alpha_prev - sigma ** 2 , min = 1e-8 ))
122+ x_inter = torch .sqrt (alpha_prev ) * x0 + sqrt_term * predicted_noise
123+ predicted_noise_inter = self ._get_predicted_noise (model , x_inter , t_prev_tensor , labels , cfg_scale )
124+ predicted_noise = (3 * predicted_noise - predicted_noise_inter ) / 2
129125
130126 # DPM++ 3M (third-order)
131127 elif self .order == 3 :
@@ -134,20 +130,25 @@ def _sample_loop(
134130 t_next_tensor = (torch .ones (n ) * t_next ).long ().to (self .device )
135131 alpha_next = self .alpha_hat [t_next_tensor ][:, None , None , None ]
136132
137- # First intermediate step, 1e-8 to avoid NaN
133+ # First intermediate step
138134 sqrt_term1 = torch .sqrt (torch .clamp (1 - alpha_prev - sigma ** 2 , min = 1e-8 ))
139135 x_inter1 = torch .sqrt (alpha_prev ) * x0 + sqrt_term1 * predicted_noise
140136 pred_noise1 = self ._get_predicted_noise (model , x_inter1 , t_prev_tensor , labels , cfg_scale )
141137
142- # Second intermediate step, 1e-8 to avoid NaN
138+ # Second intermediate step
143139 sqrt_term2 = torch .sqrt (torch .clamp (1 - alpha_next - sigma ** 2 , min = 1e-8 ))
144140 x_inter2 = torch .sqrt (alpha_next ) * x0 + sqrt_term2 * pred_noise1
145141 pred_noise2 = self ._get_predicted_noise (model , x_inter2 , t_next_tensor , labels , cfg_scale )
146142
147143 # Third-order correction
148144 predicted_noise = (23 * predicted_noise - 16 * pred_noise1 + 5 * pred_noise2 ) / 12
149- # Or use a more stable variant
150- # predicted_noise = (18 * predicted_noise - 12 * pred_noise1 + 3 * pred_noise2) / 9
145+ else :
146+ # Last step: no look-ahead available, fallback to 2nd-order correction
147+ sqrt_term_inter = torch .sqrt (torch .clamp (1 - alpha_prev - sigma ** 2 , min = 1e-8 ))
148+ x_inter = torch .sqrt (alpha_prev ) * x0 + sqrt_term_inter * predicted_noise
149+ predicted_noise_inter = self ._get_predicted_noise (model , x_inter , t_prev_tensor , labels ,
150+ cfg_scale )
151+ predicted_noise = (3 * predicted_noise - predicted_noise_inter ) / 2
151152
152153 # Add noise for stochastic sampling
153154 noise = torch .randn_like (x ) if t > 1 else torch .zeros_like (x )
0 commit comments