|
| 1 | +""" pyplots.ai |
| 2 | +confusion-matrix: Confusion Matrix Heatmap |
| 3 | +Library: pygal 3.1.0 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +import sys |
| 8 | + |
| 9 | +import numpy as np |
| 10 | + |
| 11 | + |
# Temporarily remove the leading sys.path entry (the script's directory, or ""
# which stands for the current directory) to avoid a name collision while
# importing pygal.
_cwd = sys.path[0] if sys.path else "."
_cwd_on_path = _cwd in sys.path
if _cwd_on_path:
    # remove() drops the first match, which is index 0 because _cwd was read from there.
    sys.path.remove(_cwd)

try:
    from pygal.graph.graph import Graph  # noqa: E402
    from pygal.style import Style  # noqa: E402
finally:
    # Restore path exactly as it was, even when the import fails.
    if _cwd_on_path:
        sys.path.insert(0, _cwd)
| 23 | + |
| 24 | + |
class ConfusionMatrixChart(Graph):
    """Custom Confusion Matrix Chart for pygal - displays classification results with annotations.

    The chart draws a square grid of coloured cells (rows = true class,
    columns = predicted class), class labels, axis titles and a colorbar
    directly into pygal's "plot" SVG node.

    Keyword arguments consumed by this class (all others go to ``Graph``):
        matrix_data: sequence of rows of counts; expected to be square.
        class_labels: one label per class, used for both rows and columns.
        colormap: sequence of hex colours ordered from low to high values.
            Falls back to ``DEFAULT_COLORMAP`` when missing or empty.
        show_values: whether to print the count inside every cell.
        x_axis_title: title drawn below the grid.
        y_axis_title: title drawn (rotated) left of the grid.
    """

    # Sequential light-to-dark blue palette used when no colormap is supplied.
    DEFAULT_COLORMAP = (
        "#f7fbff",
        "#deebf7",
        "#c6dbef",
        "#9ecae1",
        "#6baed6",
        "#4292c6",
        "#2171b5",
        "#08519c",
        "#08306b",
    )

    def __init__(self, *args, **kwargs):
        matrix_data = kwargs.pop("matrix_data", None)
        class_labels = kwargs.pop("class_labels", None)
        colormap = kwargs.pop("colormap", None)

        # Rows are copied into plain lists so that array-like inputs can be
        # truth-tested and indexed the same way as nested lists later on.
        self.matrix_data = [] if matrix_data is None else [list(row) for row in matrix_data]
        self.class_labels = [] if class_labels is None else list(class_labels)
        # An empty colormap would make every colour lookup raise IndexError.
        if colormap is not None and len(colormap) > 0:
            self.colormap = list(colormap)
        else:
            self.colormap = list(self.DEFAULT_COLORMAP)
        self.show_values = kwargs.pop("show_values", True)
        self.x_axis_title = kwargs.pop("x_axis_title", "Predicted Label")
        self.y_axis_title = kwargs.pop("y_axis_title", "True Label")
        super().__init__(*args, **kwargs)

    @staticmethod
    def _cm_hex_to_rgb(color):
        """Return the (r, g, b) integers of a ``#rrggbb`` or ``#rgb`` colour string."""
        digits = color.lstrip("#")
        if len(digits) == 3:
            # Expand CSS shorthand: "abc" -> "aabbcc".
            digits = "".join(ch * 2 for ch in digits)
        return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)

    def _interpolate_color(self, value, min_val, max_val):
        """Interpolate color for smooth gradient.

        Maps ``value`` from ``[min_val, max_val]`` onto the colormap and
        linearly blends the two neighbouring palette entries. Returns a
        ``#rrggbb`` string; when the range is degenerate the last palette
        entry is returned unchanged.
        """
        palette = self.colormap or self.DEFAULT_COLORMAP
        if max_val == min_val:
            return palette[-1]

        # Normalize value to 0-1 range (clamped for out-of-range inputs)
        normalized = (value - min_val) / (max_val - min_val)
        normalized = max(0, min(1, normalized))

        # Get position in colormap
        last_index = len(palette) - 1
        pos = normalized * last_index
        idx1 = int(pos)
        idx2 = min(idx1 + 1, last_index)
        frac = pos - idx1

        # Interpolate between colors
        r1, g1, b1 = self._cm_hex_to_rgb(palette[idx1])
        r2, g2, b2 = self._cm_hex_to_rgb(palette[idx2])

        r = int(r1 + (r2 - r1) * frac)
        g = int(g1 + (g2 - g1) * frac)
        b = int(b1 + (b2 - b1) * frac)

        return f"#{r:02x}{g:02x}{b:02x}"

    def _get_text_color(self, bg_color):
        """Get contrasting text color based on background brightness."""
        r, g, b = self._cm_hex_to_rgb(bg_color)
        # Weighted sum of the channels; below the threshold the cell is dark.
        brightness = (r * 299 + g * 587 + b * 114) / 1000
        return "#ffffff" if brightness < 140 else "#333333"

    def _cm_text(self, parent, x, y, content, size, anchor=None, fill="#333333", weight=None, rotate=None):
        """Append a sans-serif ``<text>`` node; optional anchor, weight and rotation about (x, y)."""
        text_node = self.svg.node(parent, "text", x=x, y=y)
        if anchor is not None:
            text_node.set("text-anchor", anchor)
        text_node.set("fill", fill)
        style_parts = [f"font-size:{size}px"]
        if weight is not None:
            style_parts.append(f"font-weight:{weight}")
        style_parts.append("font-family:sans-serif")
        text_node.set("style", ";".join(style_parts))
        if rotate is not None:
            text_node.set("transform", f"rotate({rotate}, {x}, {y})")
        text_node.text = content
        return text_node

    def _cm_axis_titles(self, group, x_offset, y_offset, grid_size):
        """Draw the rotated Y-axis title and the X-axis title around the grid."""
        # Y-axis title (True Label) - rotated, centred on the grid's left side
        self._cm_text(
            group,
            x_offset - 320,
            y_offset + grid_size / 2,
            self.y_axis_title,
            48,
            anchor="middle",
            weight="bold",
            rotate=-90,
        )
        # X-axis title (Predicted Label) - below the rotated column labels
        self._cm_text(
            group,
            x_offset + grid_size / 2,
            y_offset + grid_size + 280,
            self.x_axis_title,
            48,
            anchor="middle",
            weight="bold",
        )

    def _cm_class_labels(self, group, x_offset, y_offset, cell_size, gap, grid_size):
        """Draw row labels left of the grid and rotated column labels below it."""
        step = cell_size + gap

        # Row labels (True labels) on the left
        row_font_size = min(44, int(cell_size * 0.45))
        for i, label in enumerate(self.class_labels):
            y = y_offset + i * step + cell_size / 2
            # The 0.35 * font-size shift approximates vertical centring on the cell.
            self._cm_text(
                group,
                x_offset - 25,
                y + row_font_size * 0.35,
                str(label),
                row_font_size,
                anchor="end",
                weight="600",
            )

        # Column labels (Predicted labels) at the bottom - rotated
        col_font_size = min(44, int(cell_size * 0.45))
        for j, label in enumerate(self.class_labels):
            x = x_offset + j * step + cell_size / 2
            y = y_offset + grid_size + 30
            self._cm_text(group, x, y, str(label), col_font_size, anchor="start", weight="600", rotate=45)

    def _cm_cells(self, group, n_classes, x_offset, y_offset, cell_size, gap, min_val, max_val):
        """Draw one coloured rectangle (and optionally its count) per matrix entry."""
        step = cell_size + gap
        value_font_size = min(46, int(cell_size * 0.35))
        for i in range(n_classes):
            for j in range(n_classes):
                value = self.matrix_data[i][j]
                color = self._interpolate_color(value, min_val, max_val)

                x = x_offset + j * step
                y = y_offset + i * step

                # Highlight diagonal (correct predictions) with subtle border
                on_diagonal = i == j
                rect = self.svg.node(group, "rect", x=x, y=y, width=cell_size, height=cell_size, rx=6, ry=6)
                rect.set("fill", color)
                rect.set("stroke", "#306998" if on_diagonal else "#ffffff")
                rect.set("stroke-width", "4" if on_diagonal else "2")

                # Add value annotation in a colour that contrasts with the cell
                if self.show_values:
                    self._cm_text(
                        group,
                        x + cell_size / 2,
                        y + cell_size / 2 + value_font_size * 0.35,
                        str(int(value)),
                        value_font_size,
                        anchor="middle",
                        fill=self._get_text_color(color),
                        weight="bold",
                    )

    def _cm_colorbar(self, group, x_offset, y_offset, grid_size, min_val, max_val):
        """Draw the gradient colorbar right of the grid with max/mid/min labels and a title."""
        colorbar_width = 55
        colorbar_height = grid_size * 0.85
        colorbar_x = x_offset + grid_size + 90
        colorbar_y = y_offset + (grid_size - colorbar_height) / 2

        # Gradient built from stacked segments, highest value at the top
        n_segments = 50
        segment_height = colorbar_height / n_segments
        for seg_i in range(n_segments):
            seg_value = min_val + (max_val - min_val) * (n_segments - 1 - seg_i) / (n_segments - 1)
            # Segments overlap the next one by 1px to hide seams; the last
            # one has no successor and must not spill past the border.
            overlap = 1 if seg_i < n_segments - 1 else 0
            self.svg.node(
                group,
                "rect",
                x=colorbar_x,
                y=colorbar_y + seg_i * segment_height,
                width=colorbar_width,
                height=segment_height + overlap,
                fill=self._interpolate_color(seg_value, min_val, max_val),
            )

        # Colorbar border
        self.svg.node(
            group,
            "rect",
            x=colorbar_x,
            y=colorbar_y,
            width=colorbar_width,
            height=colorbar_height,
            fill="none",
            stroke="#333333",
        )

        # Colorbar labels: max at the top, mid in the centre, min at the bottom
        cb_label_size = 38
        label_x = colorbar_x + colorbar_width + 15
        baseline_shift = cb_label_size * 0.35
        self._cm_text(group, label_x, colorbar_y + baseline_shift, str(int(max_val)), cb_label_size)
        self._cm_text(
            group,
            label_x,
            colorbar_y + colorbar_height / 2 + baseline_shift,
            str(int((min_val + max_val) / 2)),
            cb_label_size,
        )
        self._cm_text(
            group, label_x, colorbar_y + colorbar_height + baseline_shift, str(int(min_val)), cb_label_size
        )

        # Colorbar title
        self._cm_text(
            group,
            colorbar_x + colorbar_width / 2,
            colorbar_y - 35,
            "Count",
            42,
            anchor="middle",
            weight="bold",
        )

    def _plot(self):
        """Draw the confusion matrix.

        Does nothing when there is no matrix data. Otherwise computes a
        centred square grid inside the view and draws, in order: axis
        titles, class labels, cells and the colorbar.
        """
        if not self.matrix_data:
            return

        n_classes = len(self.matrix_data)

        # Find value range
        all_values = [v for row in self.matrix_data for v in row]
        if not all_values:
            # Rows exist but are all empty: nothing to colour or annotate.
            return
        min_val = min(all_values)
        max_val = max(all_values)

        # Fixed margins (in pixels) reserved for labels, titles and the colorbar
        label_margin_left = 400
        label_margin_bottom = 350
        label_margin_top = 80
        label_margin_right = 280

        available_width = self.view.width - label_margin_left - label_margin_right
        available_height = self.view.height - label_margin_bottom - label_margin_top

        # Square cells for confusion matrix
        cell_size = min(available_width, available_height) / n_classes * 0.92
        gap = cell_size * 0.03

        # Total grid extent: n cells plus the gaps between them
        grid_size = n_classes * (cell_size + gap) - gap

        # Center the grid inside the available area
        x_offset = self.view.x(0) + label_margin_left + (available_width - grid_size) / 2
        y_offset = self.view.y(n_classes) + label_margin_top + (available_height - grid_size) / 2

        # Create group for the chart
        cm_group = self.svg.node(self.nodes["plot"], class_="confusion-matrix")

        self._cm_axis_titles(cm_group, x_offset, y_offset, grid_size)
        self._cm_class_labels(cm_group, x_offset, y_offset, cell_size, gap, grid_size)
        self._cm_cells(cm_group, n_classes, x_offset, y_offset, cell_size, gap, min_val, max_val)
        self._cm_colorbar(cm_group, x_offset, y_offset, grid_size, min_val, max_val)

    def _compute(self):
        """Compute the box for rendering.

        The box spans ``0..n_classes`` on both axes (1 when there is no data).
        """
        n_classes = len(self.matrix_data) if self.matrix_data else 1
        self._box.xmin = 0
        self._box.xmax = n_classes
        self._box.ymin = 0
        self._box.ymax = n_classes
| 258 | + |
| 259 | + |
# Data: Multi-class classification results (e.g., sentiment analysis)
# NOTE: the matrix below is hard-coded, so this seed does not affect the output;
# it only matters if randomly sampled data is added.
np.random.seed(42)

# Class names for a sentiment analysis model
class_names = ["Positive", "Neutral", "Negative", "Mixed"]
n_classes = len(class_names)

# Create realistic confusion matrix with:
# - High values on diagonal (correct predictions)
# - Common misclassifications (Neutral confused with Mixed, etc.)
# Rows are true classes, columns are predicted classes, both in class_names order.
confusion_matrix = [
    [142, 12, 5, 8],  # True Positive: mostly correct, some confused with Neutral
    [18, 98, 15, 22],  # True Neutral: often confused with others
    [7, 9, 125, 11],  # True Negative: mostly correct
    [14, 28, 18, 89],  # True Mixed: hardest to classify, often confused with Neutral
]

# Custom style for 3600x3600 square canvas
custom_style = Style(
    background="white",
    plot_background="white",
    foreground="#333333",
    foreground_strong="#333333",
    foreground_subtle="#666666",
    colors=("#306998",),
    title_font_size=72,
    legend_font_size=48,
    label_font_size=44,
    value_font_size=38,
    font_family="sans-serif",
)

# Sequential blue colormap (low values = light, high values = dark blue)
blue_colormap = [
    "#f7fbff",  # Very light
    "#deebf7",
    "#c6dbef",
    "#9ecae1",
    "#6baed6",
    "#4292c6",
    "#2171b5",
    "#08519c",
    "#08306b",  # Dark blue (Python Blue inspired)
]

# Create confusion matrix chart
chart = ConfusionMatrixChart(
    width=3600,
    height=3600,
    style=custom_style,
    title="confusion-matrix · pygal · pyplots.ai",
    matrix_data=confusion_matrix,
    class_labels=class_names,
    colormap=blue_colormap,
    show_values=True,
    x_axis_title="Predicted Label",
    y_axis_title="True Label",
    show_legend=False,
    margin=120,
    margin_top=200,
    margin_bottom=100,
    show_x_labels=False,
    show_y_labels=False,
)

# Add a dummy series to trigger _plot (pygal requires at least one series)
chart.add("", [0])

# Save outputs
chart.render_to_file("plot.svg")
chart.render_to_png("plot.png")

# Also save HTML for interactivity.
# The SVG is inlined into the page, so the XML prolog is suppressed for this
# render only: an "<?xml ...?>" declaration is not valid inside an HTML body.
html_content = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>confusion-matrix - pygal</title>
    <style>
        body {{ margin: 0; display: flex; justify-content: center; align-items: center; min-height: 100vh; background: #f5f5f5; }}
        .chart {{ max-width: 100%; height: auto; }}
    </style>
</head>
<body>
    <figure class="chart">
        {chart.render(is_unicode=True, disable_xml_declaration=True)}
    </figure>
</body>
</html>
"""

with open("plot.html", "w", encoding="utf-8") as f:
    f.write(html_content)
0 commit comments