@@ -900,11 +900,11 @@ def multistep_uni_p_bh_update(
900900 alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ])
901901 lambda_si = torch .log (alpha_si ) - torch .log (sigma_si )
902902 rk = (lambda_si - lambda_s0 ) / h
903- rks .append (rk . to ( device ) )
903+ rks .append (rk )
904904 D1s .append ((mi - m0 ) / rk )
905905
906- rks .append (torch . ones ((), device = device ) )
907- rks = torch .stack (rks )
906+ rks .append (1.0 )
907+ rks = torch .tensor (rks , device = device )
908908
909909 R = []
910910 b = []
@@ -924,18 +924,18 @@ def multistep_uni_p_bh_update(
924924
925925 for i in range (1 , order + 1 ):
926926 R .append (torch .pow (rks , i - 1 ))
927- b .append (( h_phi_k * factorial_i / B_h ). to ( device ) )
927+ b .append (h_phi_k * factorial_i / B_h )
928928 factorial_i *= i + 1
929929 h_phi_k = h_phi_k / hh - 1 / factorial_i
930930
931931 R = torch .stack (R )
932- b = torch .stack ( b ) if len ( b ) > 0 else torch . tensor (b , device = device )
932+ b = torch .tensor (b , device = device )
933933
934934 if len (D1s ) > 0 :
935935 D1s = torch .stack (D1s , dim = 1 ) # (B, K)
936936 # for order 2, we use a simplified version
937937 if order == 2 :
938- rhos_p = torch .ones ( 1 , dtype = x .dtype , device = device ) * 0.5
938+ rhos_p = torch .tensor ([ 0.5 ] , dtype = x .dtype , device = device )
939939 else :
940940 rhos_p = torch .linalg .solve (R [:- 1 , :- 1 ], b [:- 1 ]).to (device ).to (x .dtype )
941941 else :
@@ -1035,11 +1035,11 @@ def multistep_uni_c_bh_update(
10351035 alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ])
10361036 lambda_si = torch .log (alpha_si ) - torch .log (sigma_si )
10371037 rk = (lambda_si - lambda_s0 ) / h
1038- rks .append (rk . to ( device ) )
1038+ rks .append (rk )
10391039 D1s .append ((mi - m0 ) / rk )
10401040
1041- rks .append (torch . ones ((), device = device ) )
1042- rks = torch .stack (rks )
1041+ rks .append (1.0 )
1042+ rks = torch .tensor (rks , device = device )
10431043
10441044 R = []
10451045 b = []
@@ -1059,12 +1059,12 @@ def multistep_uni_c_bh_update(
10591059
10601060 for i in range (1 , order + 1 ):
10611061 R .append (torch .pow (rks , i - 1 ))
1062- b .append (( h_phi_k * factorial_i / B_h ). to ( device ) )
1062+ b .append (h_phi_k * factorial_i / B_h )
10631063 factorial_i *= i + 1
10641064 h_phi_k = h_phi_k / hh - 1 / factorial_i
10651065
10661066 R = torch .stack (R )
1067- b = torch .stack ( b ) if len ( b ) > 0 else torch . tensor (b , device = device )
1067+ b = torch .tensor (b , device = device )
10681068
10691069 if len (D1s ) > 0 :
10701070 D1s = torch .stack (D1s , dim = 1 )
@@ -1073,7 +1073,7 @@ def multistep_uni_c_bh_update(
10731073
10741074 # for order 1, we use a simplified version
10751075 if order == 1 :
1076- rhos_c = torch .ones ( 1 , dtype = x .dtype , device = device ) * 0.5
1076+ rhos_c = torch .tensor ([ 0.5 ] , dtype = x .dtype , device = device )
10771077 else :
10781078 rhos_c = torch .linalg .solve (R , b ).to (device ).to (x .dtype )
10791079
0 commit comments