Skip to content

Commit ccbdafd

Browse files
feat(pygal): implement confusion-matrix (#2294)
## Implementation: `confusion-matrix` - pygal Implements the **pygal** version of `confusion-matrix`. **File:** `plots/confusion-matrix/implementations/pygal.py` --- :robot: *[impl-generate workflow](https://github.com/MarkusNeusinger/pyplots/actions/runs/20526593863)* --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 97c6f62 commit ccbdafd

2 files changed

Lines changed: 378 additions & 0 deletions

File tree

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
""" pyplots.ai
2+
confusion-matrix: Confusion Matrix Heatmap
3+
Library: pygal 3.1.0 | Python 3.13.11
4+
Quality: 91/100 | Created: 2025-12-26
5+
"""
6+
7+
import sys
8+
9+
import numpy as np
10+
11+
12+
# Temporarily remove current directory from path to avoid name collision
13+
_cwd = sys.path[0] if sys.path[0] else "."
14+
if _cwd in sys.path:
15+
sys.path.remove(_cwd)
16+
17+
from pygal.graph.graph import Graph # noqa: E402
18+
from pygal.style import Style # noqa: E402
19+
20+
21+
# Restore path
22+
sys.path.insert(0, _cwd)
23+
24+
25+
class ConfusionMatrixChart(Graph):
26+
"""Custom Confusion Matrix Chart for pygal - displays classification results with annotations."""
27+
28+
def __init__(self, *args, **kwargs):
29+
self.matrix_data = kwargs.pop("matrix_data", [])
30+
self.class_labels = kwargs.pop("class_labels", [])
31+
self.colormap = kwargs.pop("colormap", [])
32+
self.show_values = kwargs.pop("show_values", True)
33+
self.x_axis_title = kwargs.pop("x_axis_title", "Predicted Label")
34+
self.y_axis_title = kwargs.pop("y_axis_title", "True Label")
35+
super().__init__(*args, **kwargs)
36+
37+
def _interpolate_color(self, value, min_val, max_val):
38+
"""Interpolate color for smooth gradient."""
39+
if max_val == min_val:
40+
return self.colormap[-1]
41+
42+
# Normalize value to 0-1 range
43+
normalized = (value - min_val) / (max_val - min_val)
44+
normalized = max(0, min(1, normalized))
45+
46+
# Get position in colormap
47+
pos = normalized * (len(self.colormap) - 1)
48+
idx1 = int(pos)
49+
idx2 = min(idx1 + 1, len(self.colormap) - 1)
50+
frac = pos - idx1
51+
52+
# Interpolate between colors
53+
c1 = self.colormap[idx1]
54+
c2 = self.colormap[idx2]
55+
56+
r1, g1, b1 = int(c1[1:3], 16), int(c1[3:5], 16), int(c1[5:7], 16)
57+
r2, g2, b2 = int(c2[1:3], 16), int(c2[3:5], 16), int(c2[5:7], 16)
58+
59+
r = int(r1 + (r2 - r1) * frac)
60+
g = int(g1 + (g2 - g1) * frac)
61+
b = int(b1 + (b2 - b1) * frac)
62+
63+
return f"#{r:02x}{g:02x}{b:02x}"
64+
65+
def _get_text_color(self, bg_color):
66+
"""Get contrasting text color based on background brightness."""
67+
r, g, b = int(bg_color[1:3], 16), int(bg_color[3:5], 16), int(bg_color[5:7], 16)
68+
brightness = (r * 299 + g * 587 + b * 114) / 1000
69+
return "#ffffff" if brightness < 140 else "#333333"
70+
71+
def _plot(self):
72+
"""Draw the confusion matrix."""
73+
if not self.matrix_data:
74+
return
75+
76+
n_classes = len(self.matrix_data)
77+
78+
# Find value range
79+
all_values = [v for row in self.matrix_data for v in row]
80+
min_val = min(all_values)
81+
max_val = max(all_values)
82+
83+
# Get plot dimensions
84+
plot_width = self.view.width
85+
plot_height = self.view.height
86+
87+
# Calculate margins for labels
88+
label_margin_left = 400
89+
label_margin_bottom = 350
90+
label_margin_top = 80
91+
label_margin_right = 280
92+
93+
available_width = plot_width - label_margin_left - label_margin_right
94+
available_height = plot_height - label_margin_bottom - label_margin_top
95+
96+
# Square cells for confusion matrix
97+
cell_size = min(available_width, available_height) / n_classes * 0.92
98+
gap = cell_size * 0.03
99+
100+
# Calculate grid dimensions
101+
grid_size = n_classes * (cell_size + gap) - gap
102+
103+
# Center the grid
104+
x_offset = self.view.x(0) + label_margin_left + (available_width - grid_size) / 2
105+
y_offset = self.view.y(n_classes) + label_margin_top + (available_height - grid_size) / 2
106+
107+
# Create group for the chart
108+
plot_node = self.nodes["plot"]
109+
cm_group = self.svg.node(plot_node, class_="confusion-matrix")
110+
111+
# Draw Y-axis title (True Label) - rotated
112+
y_title_size = 48
113+
y_title_x = x_offset - 320
114+
y_title_y = y_offset + grid_size / 2
115+
text_node = self.svg.node(cm_group, "text", x=y_title_x, y=y_title_y)
116+
text_node.set("text-anchor", "middle")
117+
text_node.set("fill", "#333333")
118+
text_node.set("style", f"font-size:{y_title_size}px;font-weight:bold;font-family:sans-serif")
119+
text_node.set("transform", f"rotate(-90, {y_title_x}, {y_title_y})")
120+
text_node.text = self.y_axis_title
121+
122+
# Draw X-axis title (Predicted Label)
123+
x_title_size = 48
124+
x_title_x = x_offset + grid_size / 2
125+
x_title_y = y_offset + grid_size + 280
126+
text_node = self.svg.node(cm_group, "text", x=x_title_x, y=x_title_y)
127+
text_node.set("text-anchor", "middle")
128+
text_node.set("fill", "#333333")
129+
text_node.set("style", f"font-size:{x_title_size}px;font-weight:bold;font-family:sans-serif")
130+
text_node.text = self.x_axis_title
131+
132+
# Draw row labels (True labels) on the left
133+
row_font_size = min(44, int(cell_size * 0.45))
134+
for i, label in enumerate(self.class_labels):
135+
y = y_offset + i * (cell_size + gap) + cell_size / 2
136+
text_node = self.svg.node(cm_group, "text", x=x_offset - 25, y=y + row_font_size * 0.35)
137+
text_node.set("text-anchor", "end")
138+
text_node.set("fill", "#333333")
139+
text_node.set("style", f"font-size:{row_font_size}px;font-weight:600;font-family:sans-serif")
140+
text_node.text = label
141+
142+
# Draw column labels (Predicted labels) at the bottom - rotated
143+
col_font_size = min(44, int(cell_size * 0.45))
144+
for j, label in enumerate(self.class_labels):
145+
x = x_offset + j * (cell_size + gap) + cell_size / 2
146+
y = y_offset + grid_size + 30
147+
text_node = self.svg.node(cm_group, "text", x=x, y=y)
148+
text_node.set("text-anchor", "start")
149+
text_node.set("fill", "#333333")
150+
text_node.set("style", f"font-size:{col_font_size}px;font-weight:600;font-family:sans-serif")
151+
text_node.set("transform", f"rotate(45, {x}, {y})")
152+
text_node.text = label
153+
154+
# Draw cells with values
155+
value_font_size = min(46, int(cell_size * 0.35))
156+
for i in range(n_classes):
157+
for j in range(n_classes):
158+
value = self.matrix_data[i][j]
159+
color = self._interpolate_color(value, min_val, max_val)
160+
text_color = self._get_text_color(color)
161+
162+
x = x_offset + j * (cell_size + gap)
163+
y = y_offset + i * (cell_size + gap)
164+
165+
# Highlight diagonal (correct predictions) with subtle border
166+
stroke_color = "#306998" if i == j else "#ffffff"
167+
stroke_width = "4" if i == j else "2"
168+
169+
# Draw cell rectangle
170+
rect = self.svg.node(cm_group, "rect", x=x, y=y, width=cell_size, height=cell_size, rx=6, ry=6)
171+
rect.set("fill", color)
172+
rect.set("stroke", stroke_color)
173+
rect.set("stroke-width", stroke_width)
174+
175+
# Add value annotation
176+
if self.show_values:
177+
text_x = x + cell_size / 2
178+
text_y = y + cell_size / 2 + value_font_size * 0.35
179+
180+
text_node = self.svg.node(cm_group, "text", x=text_x, y=text_y)
181+
text_node.set("text-anchor", "middle")
182+
text_node.set("fill", text_color)
183+
text_node.set("style", f"font-size:{value_font_size}px;font-weight:bold;font-family:sans-serif")
184+
text_node.text = str(int(value))
185+
186+
# Draw colorbar on the right
187+
colorbar_width = 55
188+
colorbar_height = grid_size * 0.85
189+
colorbar_x = x_offset + grid_size + 90
190+
colorbar_y = y_offset + (grid_size - colorbar_height) / 2
191+
192+
# Draw gradient colorbar
193+
n_segments = 50
194+
segment_height = colorbar_height / n_segments
195+
for seg_i in range(n_segments):
196+
seg_value = min_val + (max_val - min_val) * (n_segments - 1 - seg_i) / (n_segments - 1)
197+
seg_color = self._interpolate_color(seg_value, min_val, max_val)
198+
seg_y = colorbar_y + seg_i * segment_height
199+
200+
self.svg.node(
201+
cm_group, "rect", x=colorbar_x, y=seg_y, width=colorbar_width, height=segment_height + 1, fill=seg_color
202+
)
203+
204+
# Colorbar border
205+
self.svg.node(
206+
cm_group,
207+
"rect",
208+
x=colorbar_x,
209+
y=colorbar_y,
210+
width=colorbar_width,
211+
height=colorbar_height,
212+
fill="none",
213+
stroke="#333333",
214+
)
215+
216+
# Colorbar labels
217+
cb_label_size = 38
218+
# Max value
219+
text_node = self.svg.node(
220+
cm_group, "text", x=colorbar_x + colorbar_width + 15, y=colorbar_y + cb_label_size * 0.35
221+
)
222+
text_node.set("fill", "#333333")
223+
text_node.set("style", f"font-size:{cb_label_size}px;font-family:sans-serif")
224+
text_node.text = str(int(max_val))
225+
226+
# Mid value
227+
mid_y = colorbar_y + colorbar_height / 2
228+
text_node = self.svg.node(cm_group, "text", x=colorbar_x + colorbar_width + 15, y=mid_y + cb_label_size * 0.35)
229+
text_node.set("fill", "#333333")
230+
text_node.set("style", f"font-size:{cb_label_size}px;font-family:sans-serif")
231+
text_node.text = str(int((min_val + max_val) / 2))
232+
233+
# Min value
234+
text_node = self.svg.node(
235+
cm_group, "text", x=colorbar_x + colorbar_width + 15, y=colorbar_y + colorbar_height + cb_label_size * 0.35
236+
)
237+
text_node.set("fill", "#333333")
238+
text_node.set("style", f"font-size:{cb_label_size}px;font-family:sans-serif")
239+
text_node.text = str(int(min_val))
240+
241+
# Colorbar title
242+
cb_title_size = 42
243+
cb_title_x = colorbar_x + colorbar_width / 2
244+
cb_title_y = colorbar_y - 35
245+
text_node = self.svg.node(cm_group, "text", x=cb_title_x, y=cb_title_y)
246+
text_node.set("text-anchor", "middle")
247+
text_node.set("fill", "#333333")
248+
text_node.set("style", f"font-size:{cb_title_size}px;font-weight:bold;font-family:sans-serif")
249+
text_node.text = "Count"
250+
251+
def _compute(self):
252+
"""Compute the box for rendering."""
253+
n_classes = len(self.matrix_data) if self.matrix_data else 1
254+
self._box.xmin = 0
255+
self._box.xmax = n_classes
256+
self._box.ymin = 0
257+
self._box.ymax = n_classes
258+
259+
260+
# Data: Multi-class classification results (e.g., sentiment analysis)
261+
np.random.seed(42)
262+
263+
# Class names for a sentiment analysis model
264+
class_names = ["Positive", "Neutral", "Negative", "Mixed"]
265+
n_classes = len(class_names)
266+
267+
# Create realistic confusion matrix with:
268+
# - High values on diagonal (correct predictions)
269+
# - Common misclassifications (Neutral confused with Mixed, etc.)
270+
confusion_matrix = [
271+
[142, 12, 5, 8], # True Positive: mostly correct, some confused with Neutral
272+
[18, 98, 15, 22], # True Neutral: often confused with others
273+
[7, 9, 125, 11], # True Negative: mostly correct
274+
[14, 28, 18, 89], # True Mixed: hardest to classify, often confused with Neutral
275+
]
276+
277+
# Custom style for 3600x3600 square canvas
278+
custom_style = Style(
279+
background="white",
280+
plot_background="white",
281+
foreground="#333333",
282+
foreground_strong="#333333",
283+
foreground_subtle="#666666",
284+
colors=("#306998",),
285+
title_font_size=72,
286+
legend_font_size=48,
287+
label_font_size=44,
288+
value_font_size=38,
289+
font_family="sans-serif",
290+
)
291+
292+
# Sequential blue colormap (low values = light, high values = dark blue)
293+
blue_colormap = [
294+
"#f7fbff", # Very light
295+
"#deebf7",
296+
"#c6dbef",
297+
"#9ecae1",
298+
"#6baed6",
299+
"#4292c6",
300+
"#2171b5",
301+
"#08519c",
302+
"#08306b", # Dark blue (Python Blue inspired)
303+
]
304+
305+
# Create confusion matrix chart
306+
chart = ConfusionMatrixChart(
307+
width=3600,
308+
height=3600,
309+
style=custom_style,
310+
title="confusion-matrix · pygal · pyplots.ai",
311+
matrix_data=confusion_matrix,
312+
class_labels=class_names,
313+
colormap=blue_colormap,
314+
show_values=True,
315+
x_axis_title="Predicted Label",
316+
y_axis_title="True Label",
317+
show_legend=False,
318+
margin=120,
319+
margin_top=200,
320+
margin_bottom=100,
321+
show_x_labels=False,
322+
show_y_labels=False,
323+
)
324+
325+
# Add a dummy series to trigger _plot (pygal requires at least one series)
326+
chart.add("", [0])
327+
328+
# Save outputs
329+
chart.render_to_file("plot.svg")
330+
chart.render_to_png("plot.png")
331+
332+
# Also save HTML for interactivity
333+
html_content = f"""<!DOCTYPE html>
334+
<html>
335+
<head>
336+
<meta charset="utf-8">
337+
<title>confusion-matrix - pygal</title>
338+
<style>
339+
body {{ margin: 0; display: flex; justify-content: center; align-items: center; min-height: 100vh; background: #f5f5f5; }}
340+
.chart {{ max-width: 100%; height: auto; }}
341+
</style>
342+
</head>
343+
<body>
344+
<figure class="chart">
345+
{chart.render(is_unicode=True)}
346+
</figure>
347+
</body>
348+
</html>
349+
"""
350+
351+
with open("plot.html", "w", encoding="utf-8") as f:
352+
f.write(html_content)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
library: pygal
2+
specification_id: confusion-matrix
3+
created: '2025-12-26T17:37:36Z'
4+
updated: '2025-12-26T17:45:36Z'
5+
generated_by: claude-opus-4-5-20251101
6+
workflow_run: 20526593863
7+
issue: 0
8+
python_version: 3.13.11
9+
library_version: 3.1.0
10+
preview_url: https://storage.googleapis.com/pyplots-images/plots/confusion-matrix/pygal/plot.png
11+
preview_thumb: https://storage.googleapis.com/pyplots-images/plots/confusion-matrix/pygal/plot_thumb.png
12+
preview_html: https://storage.googleapis.com/pyplots-images/plots/confusion-matrix/pygal/plot.html
13+
quality_score: 91
14+
review:
15+
strengths:
16+
- Excellent visual design with proper colorbar, cell annotations, and diagonal highlighting
17+
- Creative solution to implement confusion matrix in pygal which lacks native heatmap
18+
support
19+
- Proper contrast handling for text on varying background colors
20+
- Good use of sequential blue colormap matching the specification requirements
21+
- Realistic sentiment analysis scenario with meaningful class confusion patterns
22+
weaknesses:
23+
- Uses custom class extending pygal.graph.graph.Graph which deviates from KISS principle
24+
(though necessary for pygal)
25+
- Imports from internal module (pygal.graph.graph) rather than public API
26+
- The sys.path manipulation for import is a workaround that adds complexity

0 commit comments

Comments
 (0)