Skip to content

Commit 3226259

Browse files
committed
Fix a bug
1 parent 03b112a commit 3226259

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_gann.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ def fitness_func(ga_instance, solution, solution_idx):
3232
# To avoid updating all, we can update just the one we need.
3333

3434
# However, for simplicity and to test GANN's intended flow:
35-
population_matrices = pygad.gann.population_as_matrices(num_networks=ga_instance.sol_per_pop,
35+
population_matrices = pygad.gann.population_as_matrices(population_networks=gann_instance.population_networks,
3636
population_vectors=ga_instance.population)
3737
gann_instance.update_population_trained_weights(population_trained_weights=population_matrices)
3838

3939
predictions = pygad.nn.predict(last_layer=gann_instance.population_networks[solution_idx],
4040
data_inputs=data_inputs)
4141

4242
# Mean Absolute Error
43-
abs_error = numpy.mean(numpy.abs(predictions.flatten() - data_outputs)) + 0.00000001
43+
abs_error = numpy.mean(numpy.abs(predictions - data_outputs)) + 0.00000001
4444
fitness = 1.0 / abs_error
4545
return fitness
4646

@@ -65,8 +65,8 @@ def test_nn_direct_usage():
6565
data_inputs = numpy.array([[0.1, 0.2, 0.3]])
6666
predictions = pygad.nn.predict(last_layer=output_layer, data_inputs=data_inputs)
6767

68-
assert predictions.shape == (1, 1)
69-
assert 0 <= predictions[0, 0] <= 1
68+
assert len(predictions) == 1
69+
assert 0 <= predictions[0] <= 1
7070
print("test_nn_direct_usage passed.")
7171

7272
if __name__ == "__main__":

0 commit comments

Comments
 (0)