@@ -900,7 +900,7 @@ def multistep_uni_p_bh_update(
900900 alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ])
901901 lambda_si = torch .log (alpha_si ) - torch .log (sigma_si )
902902 rk = (lambda_si - lambda_s0 ) / h
903- rks .append (rk )
903+ rks .append (rk . to ( device ) )
904904 D1s .append ((mi - m0 ) / rk )
905905
906906 rks .append (torch .ones ((), device = device ))
@@ -924,7 +924,7 @@ def multistep_uni_p_bh_update(
924924
925925 for i in range (1 , order + 1 ):
926926 R .append (torch .pow (rks , i - 1 ))
927- b .append (h_phi_k * factorial_i / B_h )
927+ b .append (( h_phi_k * factorial_i / B_h ). to ( device ) )
928928 factorial_i *= i + 1
929929 h_phi_k = h_phi_k / hh - 1 / factorial_i
930930
@@ -1035,7 +1035,7 @@ def multistep_uni_c_bh_update(
10351035 alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ])
10361036 lambda_si = torch .log (alpha_si ) - torch .log (sigma_si )
10371037 rk = (lambda_si - lambda_s0 ) / h
1038- rks .append (rk )
1038+ rks .append (rk . to ( device ) )
10391039 D1s .append ((mi - m0 ) / rk )
10401040
10411041 rks .append (torch .ones ((), device = device ))
@@ -1059,7 +1059,7 @@ def multistep_uni_c_bh_update(
10591059
10601060 for i in range (1 , order + 1 ):
10611061 R .append (torch .pow (rks , i - 1 ))
1062- b .append (h_phi_k * factorial_i / B_h )
1062+ b .append (( h_phi_k * factorial_i / B_h ). to ( device ) )
10631063 factorial_i *= i + 1
10641064 h_phi_k = h_phi_k / hh - 1 / factorial_i
10651065
0 commit comments