@@ -202,14 +202,14 @@ def test_KMN_with_2d_gaussian_sampling(self):
202202 x_cond = 5 * np .ones (shape = (2000000 ,1 ))
203203 _ , y_sample = model .sample (x_cond )
204204 print (np .mean (y_sample ), np .std (y_sample ))
205- self .assertAlmostEqual (np .mean (y_sample ), float ( model .mean_ (x_cond [1 ])), places = 1 )
206- self .assertAlmostEqual (np .std (y_sample ), float ( model .covariance (x_cond [1 ])), places = 1 )
205+ self .assertAlmostEqual (np .mean (y_sample ), model .mean_ (x_cond [1 ]). item ( ), places = 1 )
206+ self .assertAlmostEqual (np .std (y_sample ), model .covariance (x_cond [1 ]). item ( ), places = 1 )
207207
208208 x_cond = np .ones (shape = (400000 , 1 ))
209209 x_cond [0 ,0 ] = 5.0
210210 _ , y_sample = model .sample (x_cond )
211- self .assertAlmostEqual (np .mean (y_sample ), float ( model .mean_ (x_cond [1 ])), places = 1 )
212- self .assertAlmostEqual (np .std (y_sample ), float ( np .sqrt (model .covariance (x_cond [1 ]))), places = 1 )
211+ self .assertAlmostEqual (np .mean (y_sample ), model .mean_ (x_cond [1 ]). item ( ), places = 1 )
212+ self .assertAlmostEqual (np .std (y_sample ), np .sqrt (model .covariance (x_cond [1 ])). item ( ), places = 1 )
213213
214214 def test_MDN_with_2d_gaussian_sampling (self ):
215215 X , Y = self .get_samples ()
@@ -219,8 +219,8 @@ def test_MDN_with_2d_gaussian_sampling(self):
219219
220220 x_cond = np .ones (shape = (10 ** 6 ,1 ))
221221 _ , y_sample = model .sample (x_cond )
222- self .assertAlmostEqual (np .mean (y_sample ), float ( model .mean_ (y_sample [1 ])), places = 0 )
223- self .assertAlmostEqual (np .std (y_sample ), float ( model .covariance (y_sample [1 ])), places = 0 )
222+ self .assertAlmostEqual (np .mean (y_sample ), model .mean_ (y_sample [1 ]). item ( ), places = 0 )
223+ self .assertAlmostEqual (np .std (y_sample ), model .covariance (y_sample [1 ]). item ( ), places = 0 )
224224
225225 def test_MDN_with_2d_gaussian (self ):
226226 mu = 200
@@ -419,8 +419,8 @@ def test7_data_normalization(self):
419419 # test if data statistics were properly assigned to tf graph
420420 x_mean , x_std = model .x_mean , model .x_std
421421 print (x_mean , x_std )
422- mean_diff = float ( np .abs (x_mean - 20 ))
423- std_diff = float ( np .abs (x_std - 2 ) )
422+ mean_diff = np .abs (x_mean - 20 ). item ( )
423+ std_diff = np .abs (x_std - 2 ). item ( )
424424 self .assertLessEqual (mean_diff , 0.5 )
425425 self .assertLessEqual (std_diff , 0.5 )
426426
@@ -481,14 +481,14 @@ def test_MDN_adaptive_noise(self):
481481 hidden_sizes = (8 , 8 ),
482482 adaptive_noise_fn = adaptive_noise_fn , n_training_epochs = 500 )
483483 est .fit (X , Y )
484- std_999 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]
484+ std_999 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]. item ()
485485
486486 X , Y = self .get_samples (mu = 0 , std = 1 , n_samples = 1002 )
487487 est = MixtureDensityNetwork ("mdn_adaptive_noise_1002" , 1 , 1 , n_centers = 1 , y_noise_std = 0.0 , x_noise_std = 0.0 ,
488488 hidden_sizes = (8 , 8 ),
489489 adaptive_noise_fn = adaptive_noise_fn , n_training_epochs = 500 )
490490 est .fit (X , Y )
491- std_1002 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]
491+ std_1002 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]. item ()
492492
493493 self .assertLess (std_999 , std_1002 )
494494 self .assertGreater (std_1002 , 2 )
@@ -501,14 +501,14 @@ def test_NF_adaptive_noise(self):
501501 x_noise_std = 0.0 , adaptive_noise_fn = adaptive_noise_fn ,
502502 n_training_epochs = 500 )
503503 est .fit (X , Y )
504- std_999 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]
504+ std_999 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]. item ()
505505
506506 X , Y = self .get_samples (mu = 0 , std = 1 , n_samples = 1002 )
507507 est = NormalizingFlowEstimator ("nf_1002" , 1 , 1 , y_noise_std = 0.0 , n_flows = 2 , hidden_sizes = (8 ,8 ),
508508 x_noise_std = 0.0 , adaptive_noise_fn = adaptive_noise_fn ,
509509 n_training_epochs = 500 )
510510 est .fit (X , Y )
511- std_1002 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]
511+ std_1002 = est .std_ (x_cond = np .array ([[0.0 ]]))[0 ]. item ()
512512
513513 self .assertLess (std_999 , std_1002 )
514514 self .assertGreater (std_1002 , 2 )
@@ -552,7 +552,7 @@ def test_KMN_l2_regularization(self):
552552 err_no_reg = np .mean (np .abs (kmn_no_reg .pdf (x , y ) - p_true ))
553553 err_reg_l2 = np .mean (np .abs (kmn_reg_l2 .pdf (x , y ) - p_true ))
554554
555- self .assertLessEqual (err_reg_l2 , err_no_reg + 1e-3 )
555+ self .assertLessEqual (err_reg_l2 , err_no_reg + 1e-2 )
556556
557557 def test_NF_l1_regularization (self ):
558558 mu = 5
0 commit comments