Skip to content

Commit 3cc1210

Browse files
author
Ting-Yun Chang
committed
overwrite scheduling_unipc_multistep.py with main's version
1 parent d4b27f6 commit 3cc1210

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -900,11 +900,11 @@ def multistep_uni_p_bh_update(
900900
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
901901
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
902902
rk = (lambda_si - lambda_s0) / h
903-
rks.append(rk.to(device))
903+
rks.append(rk)
904904
D1s.append((mi - m0) / rk)
905905

906-
rks.append(torch.ones((), device=device))
907-
rks = torch.stack(rks)
906+
rks.append(1.0)
907+
rks = torch.tensor(rks, device=device)
908908

909909
R = []
910910
b = []
@@ -924,18 +924,18 @@ def multistep_uni_p_bh_update(
924924

925925
for i in range(1, order + 1):
926926
R.append(torch.pow(rks, i - 1))
927-
b.append((h_phi_k * factorial_i / B_h).to(device))
927+
b.append(h_phi_k * factorial_i / B_h)
928928
factorial_i *= i + 1
929929
h_phi_k = h_phi_k / hh - 1 / factorial_i
930930

931931
R = torch.stack(R)
932-
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
932+
b = torch.tensor(b, device=device)
933933

934934
if len(D1s) > 0:
935935
D1s = torch.stack(D1s, dim=1) # (B, K)
936936
# for order 2, we use a simplified version
937937
if order == 2:
938-
rhos_p = torch.ones(1, dtype=x.dtype, device=device) * 0.5
938+
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
939939
else:
940940
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
941941
else:
@@ -1035,11 +1035,11 @@ def multistep_uni_c_bh_update(
10351035
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
10361036
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
10371037
rk = (lambda_si - lambda_s0) / h
1038-
rks.append(rk.to(device))
1038+
rks.append(rk)
10391039
D1s.append((mi - m0) / rk)
10401040

1041-
rks.append(torch.ones((), device=device))
1042-
rks = torch.stack(rks)
1041+
rks.append(1.0)
1042+
rks = torch.tensor(rks, device=device)
10431043

10441044
R = []
10451045
b = []
@@ -1059,12 +1059,12 @@ def multistep_uni_c_bh_update(
10591059

10601060
for i in range(1, order + 1):
10611061
R.append(torch.pow(rks, i - 1))
1062-
b.append((h_phi_k * factorial_i / B_h).to(device))
1062+
b.append(h_phi_k * factorial_i / B_h)
10631063
factorial_i *= i + 1
10641064
h_phi_k = h_phi_k / hh - 1 / factorial_i
10651065

10661066
R = torch.stack(R)
1067-
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
1067+
b = torch.tensor(b, device=device)
10681068

10691069
if len(D1s) > 0:
10701070
D1s = torch.stack(D1s, dim=1)
@@ -1073,7 +1073,7 @@ def multistep_uni_c_bh_update(
10731073

10741074
# for order 1, we use a simplified version
10751075
if order == 1:
1076-
rhos_c = torch.ones(1, dtype=x.dtype, device=device) * 0.5
1076+
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
10771077
else:
10781078
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
10791079

0 commit comments

Comments (0)