Skip to content

Commit 04d5df6

Browse files
committed
fix(scheduler): device mismatch from upstream b114620 - move rk and b to device before torch.stack
1 parent 06e30e7 commit 04d5df6

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -900,7 +900,7 @@ def multistep_uni_p_bh_update(
900900
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
901901
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
902902
rk = (lambda_si - lambda_s0) / h
903-
rks.append(rk)
903+
rks.append(rk.to(device))
904904
D1s.append((mi - m0) / rk)
905905

906906
rks.append(torch.ones((), device=device))
@@ -924,7 +924,7 @@ def multistep_uni_p_bh_update(
924924

925925
for i in range(1, order + 1):
926926
R.append(torch.pow(rks, i - 1))
927-
b.append(h_phi_k * factorial_i / B_h)
927+
b.append((h_phi_k * factorial_i / B_h).to(device))
928928
factorial_i *= i + 1
929929
h_phi_k = h_phi_k / hh - 1 / factorial_i
930930

@@ -1035,7 +1035,7 @@ def multistep_uni_c_bh_update(
10351035
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
10361036
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
10371037
rk = (lambda_si - lambda_s0) / h
1038-
rks.append(rk)
1038+
rks.append(rk.to(device))
10391039
D1s.append((mi - m0) / rk)
10401040

10411041
rks.append(torch.ones((), device=device))
@@ -1059,7 +1059,7 @@ def multistep_uni_c_bh_update(
10591059

10601060
for i in range(1, order + 1):
10611061
R.append(torch.pow(rks, i - 1))
1062-
b.append(h_phi_k * factorial_i / B_h)
1062+
b.append((h_phi_k * factorial_i / B_h).to(device))
10631063
factorial_i *= i + 1
10641064
h_phi_k = h_phi_k / hh - 1 / factorial_i
10651065

0 commit comments

Comments
 (0)