Skip to content

Commit 770abee

Browse files
author
Ting-Yun Chang
committed
revert scheduler
1 parent 119822f commit 770abee

1 file changed

Lines changed: 14 additions & 14 deletions

File tree

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -882,29 +882,29 @@ def multistep_uni_p_bh_update(
882882
x_t = self.solver_p.step(model_output, s0, x).prev_sample
883883
return x_t
884884

885-
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
885+
device = sample.device
886+
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1].to(device), self.sigmas[self.step_index].to(device)
886887
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
887888
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
888889

889890
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
890891
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
891892

892893
h = lambda_t - lambda_s0
893-
device = sample.device
894894

895895
rks = []
896896
D1s = []
897897
for i in range(1, order):
898898
si = self.step_index - i
899899
mi = model_output_list[-(i + 1)]
900-
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
900+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
901901
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
902902
rk = (lambda_si - lambda_s0) / h
903903
rks.append(rk)
904904
D1s.append((mi - m0) / rk)
905905

906-
rks.append(1.0)
907-
rks = torch.tensor(rks, device=device)
906+
rks.append(torch.ones((), device=device))
907+
rks = torch.stack(rks)
908908

909909
R = []
910910
b = []
@@ -929,13 +929,13 @@ def multistep_uni_p_bh_update(
929929
h_phi_k = h_phi_k / hh - 1 / factorial_i
930930

931931
R = torch.stack(R)
932-
b = torch.tensor(b, device=device)
932+
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
933933

934934
if len(D1s) > 0:
935935
D1s = torch.stack(D1s, dim=1) # (B, K)
936936
# for order 2, we use a simplified version
937937
if order == 2:
938-
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
938+
rhos_p = torch.ones(1, dtype=x.dtype, device=device) * 0.5
939939
else:
940940
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
941941
else:
@@ -1017,29 +1017,29 @@ def multistep_uni_c_bh_update(
10171017
x_t = this_sample
10181018
model_t = this_model_output
10191019

1020-
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
1020+
device = this_sample.device
1021+
sigma_t, sigma_s0 = self.sigmas[self.step_index].to(device), self.sigmas[self.step_index - 1].to(device)
10211022
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
10221023
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
10231024

10241025
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
10251026
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
10261027

10271028
h = lambda_t - lambda_s0
1028-
device = this_sample.device
10291029

10301030
rks = []
10311031
D1s = []
10321032
for i in range(1, order):
10331033
si = self.step_index - (i + 1)
10341034
mi = model_output_list[-(i + 1)]
1035-
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
1035+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
10361036
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
10371037
rk = (lambda_si - lambda_s0) / h
10381038
rks.append(rk)
10391039
D1s.append((mi - m0) / rk)
10401040

1041-
rks.append(1.0)
1042-
rks = torch.tensor(rks, device=device)
1041+
rks.append(torch.ones((), device=device))
1042+
rks = torch.stack(rks)
10431043

10441044
R = []
10451045
b = []
@@ -1064,7 +1064,7 @@ def multistep_uni_c_bh_update(
10641064
h_phi_k = h_phi_k / hh - 1 / factorial_i
10651065

10661066
R = torch.stack(R)
1067-
b = torch.tensor(b, device=device)
1067+
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
10681068

10691069
if len(D1s) > 0:
10701070
D1s = torch.stack(D1s, dim=1)
@@ -1073,7 +1073,7 @@ def multistep_uni_c_bh_update(
10731073

10741074
# for order 1, we use a simplified version
10751075
if order == 1:
1076-
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
1076+
rhos_c = torch.ones(1, dtype=x.dtype, device=device) * 0.5
10771077
else:
10781078
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
10791079

0 commit comments

Comments
 (0)