Skip to content

Commit 3ae3935

Browse files
committed
Fix a warning and add more tests
1 parent 3226259 commit 3ae3935

File tree

2 files changed

+200
-1
lines changed

2 files changed

+200
-1
lines changed

pygad/visualize/plot.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,10 @@ def plot_fitness(self,
117117
matplt.xlabel(xlabel, fontsize=font_size)
118118
matplt.ylabel(ylabel, fontsize=font_size)
119119
# Create a legend out of the labels.
120-
matplt.legend()
120+
# Check if there is at least 1 labeled artist.
121+
# If not, the matplt.legend() method will raise a warning.
122+
if not (matplt.gca().get_legend_handles_labels()[0] == []):
123+
matplt.legend()
121124

122125
if not save_dir is None:
123126
matplt.savefig(fname=save_dir,

tests/test_visualize.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import pygad
2+
import numpy
3+
import os
4+
import matplotlib
5+
# Use Agg backend for headless testing (no GUI needed)
6+
matplotlib.use('Agg')
7+
import matplotlib.pyplot as plt
8+
9+
# Global constants for testing
10+
num_generations = 5
11+
num_parents_mating = 4
12+
sol_per_pop = 10
13+
num_genes = 3
14+
random_seed = 42
15+
16+
def fitness_func(ga_instance, solution, solution_idx):
17+
return numpy.sum(solution**2)
18+
19+
def fitness_func_multi(ga_instance, solution, solution_idx):
20+
return [numpy.sum(solution**2), numpy.sum(solution)]
21+
22+
def test_plot_fitness_parameters():
23+
"""Test all parameters of plot_fitness()."""
24+
ga_instance = pygad.GA(num_generations=num_generations,
25+
num_parents_mating=num_parents_mating,
26+
fitness_func=fitness_func,
27+
sol_per_pop=sol_per_pop,
28+
num_genes=num_genes,
29+
random_seed=random_seed,
30+
suppress_warnings=True
31+
)
32+
ga_instance.run()
33+
34+
# Test different plot types
35+
for p_type in ["plot", "scatter", "bar"]:
36+
fig = ga_instance.plot_fitness(plot_type=p_type,
37+
title=f"Title {p_type}",
38+
xlabel="X", ylabel="Y",
39+
linewidth=2, font_size=12, color="blue")
40+
assert isinstance(fig, matplotlib.figure.Figure)
41+
plt.close(fig)
42+
43+
# Test multi-objective specific parameters
44+
ga_multi = pygad.GA(num_generations=2,
45+
num_parents_mating=2,
46+
fitness_func=fitness_func_multi,
47+
sol_per_pop=5,
48+
num_genes=3,
49+
parent_selection_type="nsga2",
50+
suppress_warnings=True)
51+
ga_multi.run()
52+
53+
fig = ga_multi.plot_fitness(linewidth=[2, 4],
54+
color=["blue", "green"],
55+
label=["Obj A", "Obj B"])
56+
assert isinstance(fig, matplotlib.figure.Figure)
57+
plt.close(fig)
58+
print("test_plot_fitness_parameters passed.")
59+
60+
def test_plot_new_solution_rate_parameters():
61+
"""Test all parameters of plot_new_solution_rate() and its validation."""
62+
ga_instance = pygad.GA(num_generations=num_generations,
63+
num_parents_mating=num_parents_mating,
64+
fitness_func=fitness_func,
65+
sol_per_pop=sol_per_pop,
66+
num_genes=num_genes,
67+
random_seed=random_seed,
68+
save_solutions=True,
69+
suppress_warnings=True
70+
)
71+
ga_instance.run()
72+
73+
# Test different plot types and parameters
74+
for p_type in ["plot", "scatter", "bar"]:
75+
fig = ga_instance.plot_new_solution_rate(title=f"Rate {p_type}",
76+
plot_type=p_type,
77+
linewidth=2, color="purple")
78+
assert isinstance(fig, matplotlib.figure.Figure)
79+
plt.close(fig)
80+
81+
# Validation: Test error when save_solutions=False
82+
ga_instance_no_save = pygad.GA(num_generations=1,
83+
num_parents_mating=1,
84+
fitness_func=fitness_func,
85+
sol_per_pop=5,
86+
num_genes=2,
87+
save_solutions=False,
88+
suppress_warnings=True)
89+
ga_instance_no_save.run()
90+
try:
91+
ga_instance_no_save.plot_new_solution_rate()
92+
except RuntimeError:
93+
print("plot_new_solution_rate validation caught.")
94+
95+
print("test_plot_new_solution_rate_parameters passed.")
96+
97+
def test_plot_genes_parameters():
98+
"""Test all parameters of plot_genes()."""
99+
ga_instance = pygad.GA(num_generations=num_generations,
100+
num_parents_mating=num_parents_mating,
101+
fitness_func=fitness_func,
102+
sol_per_pop=sol_per_pop,
103+
num_genes=num_genes,
104+
random_seed=random_seed,
105+
save_solutions=True,
106+
save_best_solutions=True,
107+
suppress_warnings=True
108+
)
109+
ga_instance.run()
110+
111+
# Test different graph types and parameters
112+
for g_type in ["plot", "boxplot", "histogram"]:
113+
fig = ga_instance.plot_genes(graph_type=g_type, fill_color="yellow", color="black")
114+
assert isinstance(fig, matplotlib.figure.Figure)
115+
plt.close(fig)
116+
117+
# Test solutions="best"
118+
fig = ga_instance.plot_genes(solutions="best")
119+
assert isinstance(fig, matplotlib.figure.Figure)
120+
plt.close(fig)
121+
122+
print("test_plot_genes_parameters passed.")
123+
124+
def test_plot_pareto_front_curve_parameters():
125+
"""Test all parameters of plot_pareto_front_curve() and its validation."""
126+
ga_instance = pygad.GA(num_generations=num_generations,
127+
num_parents_mating=num_parents_mating,
128+
fitness_func=fitness_func_multi,
129+
sol_per_pop=sol_per_pop,
130+
num_genes=num_genes,
131+
random_seed=random_seed,
132+
parent_selection_type="nsga2",
133+
suppress_warnings=True
134+
)
135+
ga_instance.run()
136+
137+
fig = ga_instance.plot_pareto_front_curve(title="Pareto",
138+
linewidth=4,
139+
label="Frontier",
140+
color="red",
141+
color_fitness="black",
142+
grid=False,
143+
alpha=0.5,
144+
marker="x")
145+
assert isinstance(fig, matplotlib.figure.Figure)
146+
plt.close(fig)
147+
148+
# Validation: Test error for single-objective
149+
ga_instance_single = pygad.GA(num_generations=1,
150+
num_parents_mating=1,
151+
fitness_func=fitness_func,
152+
sol_per_pop=5,
153+
num_genes=2,
154+
suppress_warnings=True)
155+
ga_instance_single.run()
156+
try:
157+
ga_instance_single.plot_pareto_front_curve()
158+
except RuntimeError:
159+
print("plot_pareto_front_curve validation (multi-objective required) caught.")
160+
161+
print("test_plot_pareto_front_curve_parameters passed.")
162+
163+
def test_visualize_save_dir():
164+
"""Test save_dir parameter for all methods."""
165+
ga_instance = pygad.GA(num_generations=2,
166+
num_parents_mating=2,
167+
fitness_func=fitness_func,
168+
sol_per_pop=5,
169+
num_genes=2,
170+
save_solutions=True,
171+
suppress_warnings=True
172+
)
173+
ga_instance.run()
174+
175+
methods = [
176+
(ga_instance.plot_fitness, {}),
177+
(ga_instance.plot_new_solution_rate, {}),
178+
(ga_instance.plot_genes, {"graph_type": "plot"})
179+
]
180+
181+
for method, kwargs in methods:
182+
filename = f"test_{method.__name__}.png"
183+
if os.path.exists(filename): os.remove(filename)
184+
method(save_dir=filename, **kwargs)
185+
assert os.path.exists(filename)
186+
os.remove(filename)
187+
188+
print("test_visualize_save_dir passed.")
189+
190+
if __name__ == "__main__":
191+
test_plot_fitness_parameters()
192+
test_plot_new_solution_rate_parameters()
193+
test_plot_genes_parameters()
194+
test_plot_pareto_front_curve_parameters()
195+
test_visualize_save_dir()
196+
print("\nAll visualization tests passed!")

0 commit comments

Comments
 (0)