@@ -32,15 +32,15 @@ def fitness_func(ga_instance, solution, solution_idx):
3232 # To avoid updating all, we can update just the one we need.
3333
3434 # However, for simplicity and to test GANN's intended flow:
35- population_matrices = pygad .gann .population_as_matrices (num_networks = ga_instance . sol_per_pop ,
35+ population_matrices = pygad .gann .population_as_matrices (population_networks = gann_instance . population_networks ,
3636 population_vectors = ga_instance .population )
3737 gann_instance .update_population_trained_weights (population_trained_weights = population_matrices )
3838
3939 predictions = pygad .nn .predict (last_layer = gann_instance .population_networks [solution_idx ],
4040 data_inputs = data_inputs )
4141
4242 # Mean Absolute Error
43- abs_error = numpy .mean (numpy .abs (predictions . flatten () - data_outputs )) + 0.00000001
43+ abs_error = numpy .mean (numpy .abs (predictions - data_outputs )) + 0.00000001
4444 fitness = 1.0 / abs_error
4545 return fitness
4646
@@ -65,8 +65,8 @@ def test_nn_direct_usage():
6565 data_inputs = numpy .array ([[0.1 , 0.2 , 0.3 ]])
6666 predictions = pygad .nn .predict (last_layer = output_layer , data_inputs = data_inputs )
6767
68- assert predictions . shape == ( 1 , 1 )
69- assert 0 <= predictions [0 , 0 ] <= 1
68+ assert len ( predictions ) == 1
69+ assert 0 <= predictions [0 ] <= 1
7070 print ("test_nn_direct_usage passed." )
7171
7272if __name__ == "__main__" :
0 commit comments