@@ -882,29 +882,29 @@ def multistep_uni_p_bh_update(
882882 x_t = self .solver_p .step (model_output , s0 , x ).prev_sample
883883 return x_t
884884
885- sigma_t , sigma_s0 = self .sigmas [self .step_index + 1 ], self .sigmas [self .step_index ]
885+ device = sample .device
886+ sigma_t , sigma_s0 = self .sigmas [self .step_index + 1 ].to (device ), self .sigmas [self .step_index ].to (device )
886887 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t )
887888 alpha_s0 , sigma_s0 = self ._sigma_to_alpha_sigma_t (sigma_s0 )
888889
889890 lambda_t = torch .log (alpha_t ) - torch .log (sigma_t )
890891 lambda_s0 = torch .log (alpha_s0 ) - torch .log (sigma_s0 )
891892
892893 h = lambda_t - lambda_s0
893- device = sample .device
894894
895895 rks = []
896896 D1s = []
897897 for i in range (1 , order ):
898898 si = self .step_index - i
899899 mi = model_output_list [- (i + 1 )]
900- alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ])
900+ alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ]. to ( device ) )
901901 lambda_si = torch .log (alpha_si ) - torch .log (sigma_si )
902902 rk = (lambda_si - lambda_s0 ) / h
903903 rks .append (rk )
904904 D1s .append ((mi - m0 ) / rk )
905905
906- rks .append (1.0 )
907- rks = torch .tensor (rks , device = device )
906+ rks .append (torch . ones ((), device = device ) )
907+ rks = torch .stack (rks )
908908
909909 R = []
910910 b = []
@@ -929,13 +929,13 @@ def multistep_uni_p_bh_update(
929929 h_phi_k = h_phi_k / hh - 1 / factorial_i
930930
931931 R = torch .stack (R )
932- b = torch .tensor (b , device = device )
932+ b = torch .stack ( b ) if len ( b ) > 0 else torch . tensor (b , device = device )
933933
934934 if len (D1s ) > 0 :
935935 D1s = torch .stack (D1s , dim = 1 ) # (B, K)
936936 # for order 2, we use a simplified version
937937 if order == 2 :
938- rhos_p = torch .tensor ([ 0.5 ] , dtype = x .dtype , device = device )
938+ rhos_p = torch .ones ( 1 , dtype = x .dtype , device = device ) * 0.5
939939 else :
940940 rhos_p = torch .linalg .solve (R [:- 1 , :- 1 ], b [:- 1 ]).to (device ).to (x .dtype )
941941 else :
@@ -1017,29 +1017,29 @@ def multistep_uni_c_bh_update(
10171017 x_t = this_sample
10181018 model_t = this_model_output
10191019
1020- sigma_t , sigma_s0 = self .sigmas [self .step_index ], self .sigmas [self .step_index - 1 ]
1020+ device = this_sample .device
1021+ sigma_t , sigma_s0 = self .sigmas [self .step_index ].to (device ), self .sigmas [self .step_index - 1 ].to (device )
10211022 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t )
10221023 alpha_s0 , sigma_s0 = self ._sigma_to_alpha_sigma_t (sigma_s0 )
10231024
10241025 lambda_t = torch .log (alpha_t ) - torch .log (sigma_t )
10251026 lambda_s0 = torch .log (alpha_s0 ) - torch .log (sigma_s0 )
10261027
10271028 h = lambda_t - lambda_s0
1028- device = this_sample .device
10291029
10301030 rks = []
10311031 D1s = []
10321032 for i in range (1 , order ):
10331033 si = self .step_index - (i + 1 )
10341034 mi = model_output_list [- (i + 1 )]
1035- alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ])
1035+ alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ]. to ( device ) )
10361036 lambda_si = torch .log (alpha_si ) - torch .log (sigma_si )
10371037 rk = (lambda_si - lambda_s0 ) / h
10381038 rks .append (rk )
10391039 D1s .append ((mi - m0 ) / rk )
10401040
1041- rks .append (1.0 )
1042- rks = torch .tensor (rks , device = device )
1041+ rks .append (torch . ones ((), device = device ) )
1042+ rks = torch .stack (rks )
10431043
10441044 R = []
10451045 b = []
@@ -1064,7 +1064,7 @@ def multistep_uni_c_bh_update(
10641064 h_phi_k = h_phi_k / hh - 1 / factorial_i
10651065
10661066 R = torch .stack (R )
1067- b = torch .tensor (b , device = device )
1067+ b = torch .stack ( b ) if len ( b ) > 0 else torch . tensor (b , device = device )
10681068
10691069 if len (D1s ) > 0 :
10701070 D1s = torch .stack (D1s , dim = 1 )
@@ -1073,7 +1073,7 @@ def multistep_uni_c_bh_update(
10731073
10741074 # for order 1, we use a simplified version
10751075 if order == 1 :
1076- rhos_c = torch .tensor ([ 0.5 ] , dtype = x .dtype , device = device )
1076+ rhos_c = torch .ones ( 1 , dtype = x .dtype , device = device ) * 0.5
10771077 else :
10781078 rhos_c = torch .linalg .solve (R , b ).to (device ).to (x .dtype )
10791079
0 commit comments